Skip to content
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRulesCore"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "1.12"
version = "1.13"

[deps]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Expand Down
13 changes: 8 additions & 5 deletions src/projection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -379,11 +379,14 @@ end

using LinearAlgebra: AdjointAbsVec, TransposeAbsVec, AdjOrTransAbsVec

# UniformScaling can represent its own cotangent
ProjectTo(x::UniformScaling) = ProjectTo{UniformScaling}(; λ=ProjectTo(x.λ))
ProjectTo(x::UniformScaling{Bool}) = ProjectTo(false)
(pr::ProjectTo{UniformScaling})(dx::UniformScaling) = UniformScaling(pr.λ(dx.λ))
(pr::ProjectTo{UniformScaling})(dx::Tangent{<:UniformScaling}) = UniformScaling(pr.λ(dx.λ))
if VERSION >= v"1.6"
# UniformScaling can represent its own cotangent
# but shouldn't on Julia 1.0, as rules in CR.jl were added only after mioving to 1.6
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# but shouldn't on Julia 1.0, as rules in CR.jl were added only after mioving to 1.6
# but shouldn't on Julia 1.0, as rules in CR.jl were added only after moving to 1.6

ProjectTo(x::UniformScaling) = ProjectTo{UniformScaling}(; λ=ProjectTo(x.λ))
ProjectTo(x::UniformScaling{Bool}) = ProjectTo(false)
(pr::ProjectTo{UniformScaling})(dx::UniformScaling) = UniformScaling(pr.λ(dx.λ))
(pr::ProjectTo{UniformScaling})(dx::Tangent{<:UniformScaling}) = UniformScaling(pr.λ(dx.λ))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
(pr::ProjectTo{UniformScaling})(dx::Tangent{<:UniformScaling}) = UniformScaling(pr.λ(dx.λ))
function (pr::ProjectTo{UniformScaling})(dx::Tangent{<:UniformScaling})
return UniformScaling(pr.λ(dx.λ))
end

end

# Row vectors
ProjectTo(x::AdjointAbsVec) = ProjectTo{Adjoint}(; parent=ProjectTo(parent(x)))
Expand Down
2 changes: 1 addition & 1 deletion test/projection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ struct NoSuperType end
##### `LinearAlgebra`
#####

@testset "UniformScaling" begin
VERSION >= v"1.6" && @testset "UniformScaling" begin
@test ProjectTo(I)(123) === NoTangent()
@test ProjectTo(2 * I)(I * 3im) === 0.0 * I
@test ProjectTo((4 + 5im) * I)(Tangent{typeof(im * I)}(; λ = 6)) === (6.0 + 0.0im) * I
Expand Down