Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "1.21"
version = "1.21.1"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand All @@ -11,7 +11,7 @@ RealDot = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
ChainRulesCore = "1.11.5"
ChainRulesCore = "1.12"
ChainRulesTestUtils = "1"
Compat = "3.35"
FiniteDifferences = "0.12.20"
Expand Down
1 change: 1 addition & 0 deletions src/ChainRules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ include("rulesets/LinearAlgebra/matfun.jl")
include("rulesets/LinearAlgebra/structured.jl")
include("rulesets/LinearAlgebra/symmetric.jl")
include("rulesets/LinearAlgebra/factorization.jl")
include("rulesets/LinearAlgebra/uniformscaling.jl")

include("rulesets/Random/random.jl")

Expand Down
94 changes: 94 additions & 0 deletions src/rulesets/LinearAlgebra/uniformscaling.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
#####
##### constructor
#####

function rrule(::Type{T}, x::Number) where {T<:UniformScaling}
UniformScaling_back(dx) = (NoTangent(), ProjectTo(x)(unthunk(dx).λ))
return T(x), UniformScaling_back
end

#####
##### `+`
#####

function frule((_, Δx, ΔJ), ::typeof(+), x::AbstractMatrix, J::UniformScaling)
return x + J, Δx + (zero(J) + ΔJ) # This (0 + ΔJ) allows for ΔJ::Tangent{UniformScaling}
end

function frule((_, ΔJ, Δx), ::typeof(+), J::UniformScaling, x::AbstractMatrix)
return J + x, (zero(J) + ΔJ) + Δx
end

function rrule(::typeof(+), x::AbstractMatrix, J::UniformScaling)
project_x = ProjectTo(x)
project_J = ProjectTo(J)
function plus_back(dy)
dx = unthunk(dy)
return (NoTangent(), project_x(dx), project_J(I * tr(dx)))
end
return x + J, plus_back
end

function rrule(::typeof(+), J::UniformScaling, x::AbstractMatrix)
y, back = rrule(+, x, J)
function plus_back_2(dy)
df, dx, dJ = back(dy)
return (df, dJ, dx)
end
return y, plus_back_2
end

#####
##### `-`
#####

function frule((_, Δx, ΔJ), ::typeof(-), x::AbstractMatrix, J::UniformScaling)
return x - J, Δx - (zero(J) + ΔJ)
end

function frule((_, ΔJ, Δx), ::typeof(-), J::UniformScaling, x::AbstractMatrix)
return J - x, (zero(J) + ΔJ) - Δx
end

function rrule(::typeof(-), x::AbstractMatrix, J::UniformScaling)
y, back = rrule(+, x, -J)
project_J = ProjectTo(J)
function minus_back_1(dy)
df, dx, dJ = back(dy)
return (df, dx, project_J(-dJ)) # re-project as -true isa Int
end
return y, minus_back_1
end


function rrule(::typeof(-), J::UniformScaling, x::AbstractMatrix)
project_x = ProjectTo(x)
project_J = ProjectTo(J)
function minus_back_2(dy)
dx = -unthunk(dy)
return (NoTangent(), project_J(-tr(dx) * I), project_x(dx))
end
return J - x, minus_back_2
end

#####
##### `Matrix`
#####

function rrule(::Type{T}, I::UniformScaling{<:Bool}, (m, n)) where {T<:AbstractMatrix}
Matrix_back_I(dy) = (NoTangent(), NoTangent(), NoTangent())
return T(I, m, n), Matrix_back_I
end

function rrule(::Type{T}, J::UniformScaling, (m, n)) where {T<:AbstractMatrix}
project_J = ProjectTo(J)
function Matrix_back_I(dy)
dJ = if m == n
project_J(I * tr(unthunk(dy)))
else
project_J(I * sum(diag(unthunk(dy))))
end
return (NoTangent(), dJ, NoTangent())
end
return T(J, m, n), Matrix_back_I
end
35 changes: 35 additions & 0 deletions test/rulesets/LinearAlgebra/uniformscaling.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
@testset "UniformScaling rules" begin
@testset "constructor" begin
test_rrule(UniformScaling, rand())
end

@testset "+" begin
# Forward
test_frule(+, rand(3,3), I * rand(ComplexF64))
test_frule(+, I, rand(3,3))

# Reverse
test_rrule(+, rand(3,3), I)
test_rrule(+, rand(3,3), I * rand(ComplexF64))
test_rrule(+, I, rand(3,3))
test_rrule(+, I * rand(), rand(ComplexF64, 3,3))
end

@testset "-" begin
# Forward
test_frule(-, rand(3,3), I * rand(ComplexF64))
test_frule(-, I, rand(3,3))

# Reverse
test_rrule(-, rand(3,3), I)
test_rrule(-, rand(3,3), I * rand(ComplexF64))
test_rrule(-, I, rand(3,3))
test_rrule(-, I * rand(), rand(ComplexF64, 3,3))
end

@testset "Matrix" begin
test_rrule(Matrix, I, (2, 2))
test_rrule(Matrix{ComplexF64}, rand()*I, (3, 3))
test_rrule(Matrix, rand(ComplexF64)*I, (2, 4))
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ end
include_test("rulesets/LinearAlgebra/factorization.jl")
include_test("rulesets/LinearAlgebra/blas.jl")
include_test("rulesets/LinearAlgebra/lapack.jl")
include_test("rulesets/LinearAlgebra/uniformscaling.jl")

println()

Expand Down