Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRulesCore"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "1.12"
version = "1.12.1"

[deps]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Expand Down
1 change: 1 addition & 0 deletions src/projection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ ProjectTo(::Any) = identity

# Thunks
(project::ProjectTo)(dx::Thunk) = Thunk(project ∘ dx.f)
(project::ProjectTo)(dx::InplaceableThunk) = Thunk(project ∘ dx.val.f)

# Zero
ProjectTo(::AbstractZero) = ProjectTo{NoTangent}() # Any x::Zero in forward pass makes this one projector,
Expand Down
7 changes: 7 additions & 0 deletions test/projection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,13 @@ struct NoSuperType end
@test unthunk(pth) === 6.0 + 0.0im
end

@testset "InplaceableThunk" begin
it = InplaceableThunk(x -> x + 6, @thunk 1 + 2 + 3)
pt = ProjectTo(4 + 5im)(it)
@test pt isa Thunk
@test unthunk(pt) === 6.0 + 0.0im
end

@testset "Tangent" begin
x = 1:3.0
dx = Tangent{typeof(x)}(; step=0.1, ref=NoTangent())
Expand Down