From a518e5b2340ed6c5e21913873e0b00489d6caf18 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Thu, 24 Feb 2022 15:16:44 +0000 Subject: [PATCH] backport #571 --- Project.toml | 4 +- src/ChainRules.jl | 1 + src/rulesets/LinearAlgebra/uniformscaling.jl | 94 +++++++++++++++++++ test/rulesets/LinearAlgebra/uniformscaling.jl | 35 +++++++ test/runtests.jl | 1 + 5 files changed, 133 insertions(+), 2 deletions(-) create mode 100644 src/rulesets/LinearAlgebra/uniformscaling.jl create mode 100644 test/rulesets/LinearAlgebra/uniformscaling.jl diff --git a/Project.toml b/Project.toml index 2dfec9103..e2041c08a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.21" +version = "1.21.1" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -11,7 +11,7 @@ RealDot = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] -ChainRulesCore = "1.11.5" +ChainRulesCore = "1.12" ChainRulesTestUtils = "1" Compat = "3.35" FiniteDifferences = "0.12.20" diff --git a/src/ChainRules.jl b/src/ChainRules.jl index f2806fd3d..3b743350c 100644 --- a/src/ChainRules.jl +++ b/src/ChainRules.jl @@ -47,6 +47,7 @@ include("rulesets/LinearAlgebra/matfun.jl") include("rulesets/LinearAlgebra/structured.jl") include("rulesets/LinearAlgebra/symmetric.jl") include("rulesets/LinearAlgebra/factorization.jl") +include("rulesets/LinearAlgebra/uniformscaling.jl") include("rulesets/Random/random.jl") diff --git a/src/rulesets/LinearAlgebra/uniformscaling.jl b/src/rulesets/LinearAlgebra/uniformscaling.jl new file mode 100644 index 000000000..436aeac06 --- /dev/null +++ b/src/rulesets/LinearAlgebra/uniformscaling.jl @@ -0,0 +1,94 @@ +##### +##### constructor +##### + +function rrule(::Type{T}, x::Number) where {T<:UniformScaling} + UniformScaling_back(dx) = (NoTangent(), ProjectTo(x)(unthunk(dx).λ)) + return T(x), UniformScaling_back +end + +##### +##### `+` +##### + +function frule((_, Δx, ΔJ), ::typeof(+), x::AbstractMatrix, J::UniformScaling) + return x + J, Δx + (zero(J) + ΔJ) # This (0 + ΔJ) allows for ΔJ::Tangent{UniformScaling} +end + +function frule((_, ΔJ, Δx), ::typeof(+), J::UniformScaling, x::AbstractMatrix) + return J + x, (zero(J) + ΔJ) + Δx +end + +function rrule(::typeof(+), x::AbstractMatrix, J::UniformScaling) + project_x = ProjectTo(x) + project_J = ProjectTo(J) + function plus_back(dy) + dx = unthunk(dy) + return (NoTangent(), project_x(dx), project_J(I * tr(dx))) + end + return x + J, plus_back +end + +function rrule(::typeof(+), J::UniformScaling, x::AbstractMatrix) + y, back = rrule(+, x, J) + function plus_back_2(dy) + df, dx, dJ = back(dy) + return (df, dJ, dx) + end + return y, plus_back_2 +end + +##### +##### `-` +##### + +function frule((_, Δx, ΔJ), ::typeof(-), x::AbstractMatrix, J::UniformScaling) + return x - J, Δx - (zero(J) + ΔJ) +end + +function frule((_, ΔJ, Δx), ::typeof(-), J::UniformScaling, x::AbstractMatrix) + return J - x, (zero(J) + ΔJ) - Δx +end + +function rrule(::typeof(-), x::AbstractMatrix, J::UniformScaling) + y, back = rrule(+, x, -J) + project_J = ProjectTo(J) + function minus_back_1(dy) + df, dx, dJ = back(dy) + return (df, dx, project_J(-dJ)) # re-project as -true isa Int + end + return y, minus_back_1 +end + + +function rrule(::typeof(-), J::UniformScaling, x::AbstractMatrix) + project_x = ProjectTo(x) + project_J = ProjectTo(J) + function minus_back_2(dy) + dx = -unthunk(dy) + return (NoTangent(), project_J(-tr(dx) * I), project_x(dx)) + end + return J - x, minus_back_2 +end + +##### +##### `Matrix` +##### + +function rrule(::Type{T}, I::UniformScaling{<:Bool}, (m, n)) where {T<:AbstractMatrix} + Matrix_back_I(dy) = (NoTangent(), NoTangent(), NoTangent()) + return T(I, m, n), Matrix_back_I +end + +function rrule(::Type{T}, J::UniformScaling, (m, n)) where {T<:AbstractMatrix} + project_J = ProjectTo(J) + function Matrix_back_I(dy) + dJ = if m == n + project_J(I * tr(unthunk(dy))) + else + project_J(I * sum(diag(unthunk(dy)))) + end + return (NoTangent(), dJ, NoTangent()) + end + return T(J, m, n), Matrix_back_I +end diff --git a/test/rulesets/LinearAlgebra/uniformscaling.jl b/test/rulesets/LinearAlgebra/uniformscaling.jl new file mode 100644 index 000000000..320a98cdc --- /dev/null +++ b/test/rulesets/LinearAlgebra/uniformscaling.jl @@ -0,0 +1,35 @@ +@testset "UniformScaling rules" begin + @testset "constructor" begin + test_rrule(UniformScaling, rand()) + end + + @testset "+" begin + # Forward + test_frule(+, rand(3,3), I * rand(ComplexF64)) + test_frule(+, I, rand(3,3)) + + # Reverse + test_rrule(+, rand(3,3), I) + test_rrule(+, rand(3,3), I * rand(ComplexF64)) + test_rrule(+, I, rand(3,3)) + test_rrule(+, I * rand(), rand(ComplexF64, 3,3)) + end + + @testset "-" begin + # Forward + test_frule(-, rand(3,3), I * rand(ComplexF64)) + test_frule(-, I, rand(3,3)) + + # Reverse + test_rrule(-, rand(3,3), I) + test_rrule(-, rand(3,3), I * rand(ComplexF64)) + test_rrule(-, I, rand(3,3)) + test_rrule(-, I * rand(), rand(ComplexF64, 3,3)) + end + + @testset "Matrix" begin + test_rrule(Matrix, I, (2, 2)) + test_rrule(Matrix{ComplexF64}, rand()*I, (3, 3)) + test_rrule(Matrix, rand(ComplexF64)*I, (2, 4)) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 61548644b..817abbccc 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -70,6 +70,7 @@ end include_test("rulesets/LinearAlgebra/factorization.jl") include_test("rulesets/LinearAlgebra/blas.jl") include_test("rulesets/LinearAlgebra/lapack.jl") + include_test("rulesets/LinearAlgebra/uniformscaling.jl") println()