Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRulesCore"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "0.9.38"
version = "0.9.39"

[deps]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Expand Down
61 changes: 27 additions & 34 deletions src/differential_arithmetic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,65 +13,58 @@ Notice:
defines the combination this type with all types of lower precidence.
This means each eval loops is 1 item smaller than the previous.
==#
Base.:+(x::NotImplemented) = throw(NotImplementedException(x))
Base.:+(x::NotImplemented, ::NotImplemented) = throw(NotImplementedException(x))
Base.:-(x::NotImplemented, ::NotImplemented) = throw(NotImplementedException(x))
Base.:-(x::NotImplemented) = throw(NotImplementedException(x))
Base.:*(x::NotImplemented, ::NotImplemented) = throw(NotImplementedException(x))
function LinearAlgebra.dot(x::NotImplemented, ::NotImplemented)
return throw(NotImplementedException(x))
end
for T in (:DoesNotExist, :One, :AbstractThunk, :Composite, :Any)
@eval Base.:+(x::NotImplemented, ::$T) = throw(NotImplementedException(x))
@eval Base.:+(::$T, x::NotImplemented) = throw(NotImplementedException(x))
@eval Base.:-(x::NotImplemented, ::$T) = throw(NotImplementedException(x))
@eval Base.:-(::$T, x::NotImplemented) = throw(NotImplementedException(x))

@eval Base.:*(::$T, x::NotImplemented) = throw(NotImplementedException(x))

@eval LinearAlgebra.dot(x::NotImplemented, ::$T) = throw(NotImplementedException(x))
@eval LinearAlgebra.dot(::$T, x::NotImplemented) = throw(NotImplementedException(x))

# required for `@scalar_rule`
# we propagate `NotImplemented` (e.g., in `@scalar_rule`)
# this requires the following definitions (see also #337)
Base.:+(x::NotImplemented, ::Zero) = x
Base.:+(::Zero, x::NotImplemented) = x
Base.:+(x::NotImplemented, ::NotImplemented) = x
Base.:*(::NotImplemented, ::Zero) = Zero()
Base.:*(::Zero, ::NotImplemented) = Zero()
for T in (:DoesNotExist, :One, :AbstractThunk, :Composite, :Any)
@eval Base.:+(x::NotImplemented, ::$T) = x
@eval Base.:+(::$T, x::NotImplemented) = x
@eval Base.:*(x::NotImplemented, ::$T) = x
end

# required for `@scalar_rule`
Base.muladd(x::NotImplemented, y, z) = x
Base.muladd(::NotImplemented, ::Zero, z) = z
Base.muladd(x::NotImplemented, y, ::Zero) = x
Base.muladd(::NotImplemented, ::Zero, ::Zero) = Zero()

Base.muladd(x, y::NotImplemented, z) = y
Base.muladd(::Zero, ::NotImplemented, z) = z
Base.muladd(x, y::NotImplemented, ::Zero) = y
Base.muladd(::Zero, ::NotImplemented, ::Zero) = Zero()

Base.muladd(x, y, z::NotImplemented) = z
Base.muladd(::Zero, y, z::NotImplemented) = z
Base.muladd(x, ::Zero, z::NotImplemented) = z
Base.muladd(::Zero, ::Zero, z::NotImplemented) = z

Base.muladd(x::NotImplemented, ::NotImplemented, z) = x
Base.muladd(x::NotImplemented, ::NotImplemented, ::Zero) = x

Base.muladd(x::NotImplemented, y, ::NotImplemented) = x
Base.muladd(::NotImplemented, ::Zero, z::NotImplemented) = z

Base.muladd(x, y::NotImplemented, ::NotImplemented) = y
Base.muladd(::Zero, ::NotImplemented, z::NotImplemented) = z

Base.muladd(x::NotImplemented, ::NotImplemented, ::NotImplemented) = x
LinearAlgebra.dot(::NotImplemented, ::Zero) = Zero()
LinearAlgebra.dot(::Zero, ::NotImplemented) = Zero()

# similar to `DoesNotExist`, `Zero` wins `*` and `NotImplemented` wins `+`
Base.:+(x::NotImplemented, ::Zero) = throw(NotImplementedException(x))
Base.:+(::Zero, x::NotImplemented) = throw(NotImplementedException(x))
# other common operations throw an exception
Base.:+(x::NotImplemented) = throw(NotImplementedException(x))
Base.:-(x::NotImplemented) = throw(NotImplementedException(x))
Base.:-(x::NotImplemented, ::Zero) = throw(NotImplementedException(x))
Base.:-(::Zero, x::NotImplemented) = throw(NotImplementedException(x))
Base.:*(::NotImplemented, ::Zero) = Zero()
Base.:*(::Zero, ::NotImplemented) = Zero()
LinearAlgebra.dot(::NotImplemented, ::Zero) = Zero()
LinearAlgebra.dot(::Zero, ::NotImplemented) = Zero()
Base.:-(x::NotImplemented, ::NotImplemented) = throw(NotImplementedException(x))
Base.:*(x::NotImplemented, ::NotImplemented) = throw(NotImplementedException(x))
function LinearAlgebra.dot(x::NotImplemented, ::NotImplemented)
return throw(NotImplementedException(x))
end
for T in (:DoesNotExist, :One, :AbstractThunk, :Composite, :Any)
@eval Base.:-(x::NotImplemented, ::$T) = throw(NotImplementedException(x))
@eval Base.:-(::$T, x::NotImplemented) = throw(NotImplementedException(x))
@eval Base.:*(::$T, x::NotImplemented) = throw(NotImplementedException(x))
@eval LinearAlgebra.dot(x::NotImplemented, ::$T) = throw(NotImplementedException(x))
@eval LinearAlgebra.dot(::$T, x::NotImplemented) = throw(NotImplementedException(x))
end

Base.:+(::DoesNotExist, ::DoesNotExist) = DoesNotExist()
Base.:-(::DoesNotExist, ::DoesNotExist) = DoesNotExist()
Expand Down
35 changes: 35 additions & 0 deletions test/differentials/composite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,41 @@ end
d_sum = Composite{Dict}(Dict(4 => 3.0 + 3.0, 3 => 2.0, 2 => 2.0))
@test d1 + d2 == d_sum
end

@testset "Fields of type NotImplemented" begin
CFoo = Composite{Foo}
a = CFoo(x=1.5)
b = CFoo(x=@not_implemented(""))
for (x, y) in ((a, b), (b, a), (b, b))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We probably don't need to test exhaustively? IMO we could just test Composite{Foo} with (a, b), (b, a) here.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(b, b) tests +(::NotImplemented, ::NotImplemented) but, as the other two cases, in principle this is covered by the NotImplemented tests now as well. I kept the tests of the original PR since this was the actual problem that caused errors in ChainRulesTestUtils. Would you like me to remove the Composite tests?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That would be my preference, but it is not strong and I leave it up to you to make a decision. Feel free to merge as is!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it's not necessary to remove the tests, I would like to keep them since they check the actual problem that this PR tries to solve. I assume that this is an advantage if, e.g., at some point a more restrictive policy regarding summation of NotImplemented is adopted.

z = x + y
@test z isa CFoo
@test z.x isa ChainRulesCore.NotImplemented
end

a = Composite{Tuple}(1.5)
b = Composite{Tuple}(@not_implemented(""))
for (x, y) in ((a, b), (b, a), (b, b))
z = x + y
@test z isa Composite{Tuple}
@test first(z) isa ChainRulesCore.NotImplemented
end

a = Composite{NamedTuple{(:x,)}}(x=1.5)
b = Composite{NamedTuple{(:x,)}}(x=@not_implemented(""))
for (x, y) in ((a, b), (b, a), (b, b))
z = x + y
@test z isa Composite{NamedTuple{(:x,)}}
@test z.x isa ChainRulesCore.NotImplemented
end

a = Composite{Dict}(Dict(:x => 1.5))
b = Composite{Dict}(Dict(:x => @not_implemented("")))
for (x, y) in ((a, b), (b, a), (b, b))
z = x + y
@test z isa Composite{Dict}
@test z[:x] isa ChainRulesCore.NotImplemented
end
end
end

@testset "+ with Primals" begin
Expand Down
22 changes: 11 additions & 11 deletions test/differentials/notimplemented.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,17 @@
@test muladd(Zero(), y, ni) === ni
@test muladd(x, Zero(), ni) === ni
@test muladd(Zero(), Zero(), ni) === ni
@test ni + rand() === ni
@test ni + Zero() === ni
@test ni + DoesNotExist() === ni
@test ni + One() === ni
@test ni + @thunk(x^2) === ni
@test rand() + ni === ni
@test Zero() + ni === ni
@test DoesNotExist() + ni === ni
@test One() + ni === ni
@test @thunk(x^2) + ni === ni
@test ni + ni2 === ni
@test ni * rand() === ni
@test ni * Zero() == Zero()
@test Zero() * ni == Zero()
Expand All @@ -40,17 +51,6 @@
E = ChainRulesCore.NotImplementedException
@test_throws E extern(ni)
@test_throws E +ni
@test_throws E ni + rand()
@test_throws E ni + Zero()
@test_throws E ni + DoesNotExist()
@test_throws E ni + One()
@test_throws E ni + @thunk(x^2)
@test_throws E rand() + ni
@test_throws E Zero() + ni
@test_throws E DoesNotExist() + ni
@test_throws E One() + ni
@test_throws E @thunk(x^2) + ni
@test_throws E ni + ni2
@test_throws E -ni
@test_throws E ni - rand()
@test_throws E ni - Zero()
Expand Down