diff --git a/Project.toml b/Project.toml index a7d74bf83..9f22f8637 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesCore" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "1.12" +version = "1.12.1" [deps] Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" diff --git a/src/projection.jl b/src/projection.jl index 5967c7ced..8eba26353 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -122,6 +122,7 @@ ProjectTo(::Any) = identity # Thunks (project::ProjectTo)(dx::Thunk) = Thunk(project ∘ dx.f) +(project::ProjectTo)(dx::InplaceableThunk) = project(dx.val) # Zero ProjectTo(::AbstractZero) = ProjectTo{NoTangent}() # Any x::Zero in forward pass makes this one projector, diff --git a/test/projection.jl b/test/projection.jl index a55d42909..3e70772ac 100644 --- a/test/projection.jl +++ b/test/projection.jl @@ -427,6 +427,13 @@ struct NoSuperType end @test unthunk(pth) === 6.0 + 0.0im end + @testset "InplaceableThunk" begin + it = InplaceableThunk(x -> x + 6, @thunk 1 + 2 + 3) + pt = ProjectTo(4 + 5im)(it) + @test pt isa Thunk + @test unthunk(pt) === 6.0 + 0.0im + end + @testset "Tangent" begin x = 1:3.0 dx = Tangent{typeof(x)}(; step=0.1, ref=NoTangent())