From 116935586942f0804c74f503acf2a875fa39fc70 Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Fri, 28 Apr 2023 12:40:00 +0200 Subject: [PATCH 1/8] Handle matrix times matrix = vector case --- Project.toml | 2 +- src/derivatives/linalg/arithmetic.jl | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index fdcd5f7f..cd0673fa 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ReverseDiff" uuid = "37e2e3b7-166d-5795-8a7a-e32c996b4267" -version = "1.14.5" +version = "1.14.6" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/derivatives/linalg/arithmetic.jl b/src/derivatives/linalg/arithmetic.jl index 271af226..be0c2a80 100644 --- a/src/derivatives/linalg/arithmetic.jl +++ b/src/derivatives/linalg/arithmetic.jl @@ -273,6 +273,10 @@ function reverse_mul!(output, output_deriv, a, b, a_tmp, b_tmp) istracked(a) && increment_deriv!(a, mul!(a_tmp, output_deriv, transpose(value(b)))) istracked(b) && increment_deriv!(b, mul!(b_tmp, transpose(value(a)), output_deriv)) end +function reverse_mul!(output, output_deriv::AbstractMatrix, a, b::AbstractMatrix, a_tmp::AbstractVector, b_tmp) + istracked(a) && increment_deriv!(a, mul!(reshape(a_tmp, :, 1), output_deriv, transpose(value(b)))) + istracked(b) && increment_deriv!(b, mul!(b_tmp, transpose(value(a)), output_deriv)) +end for (f, F) in ((:transpose, :Transpose), (:adjoint, :Adjoint)) @eval begin From 9d6466e48f9071dd10cb5911483047f6c89a9219 Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Thu, 4 May 2023 18:39:56 +0200 Subject: [PATCH 2/8] add test --- test/derivatives/LinAlgTests.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/derivatives/LinAlgTests.jl b/test/derivatives/LinAlgTests.jl index ee2b2ef7..f3679497 100644 --- a/test/derivatives/LinAlgTests.jl +++ b/test/derivatives/LinAlgTests.jl @@ -223,8 +223,11 @@ for f in ( test_arr2num(f, x, tp) end +vec_to_hermitian = (v) -> begin A = I - 2 * v * collect(v'); A = collect(A') * A end; + for f in ( y -> vec(y)' * Matrix{Float64}(I, length(y), length(y)) * vec(y), + y -> norm(vec_to_hermitian(y)), ) test_println("Array -> Number functions", f) test_arr2num(f, x, tp, ignore_tape_length=true) From c41c6e0cae9bd97fc43766f572106977b77dcb83 Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Fri, 5 May 2023 16:29:39 +0200 Subject: [PATCH 3/8] actually cover it --- test/api/GradientTests.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/test/api/GradientTests.jl b/test/api/GradientTests.jl index 858bf5d7..5f37be32 100644 --- a/test/api/GradientTests.jl +++ b/test/api/GradientTests.jl @@ -187,6 +187,12 @@ for f in DiffTests.VECTOR_TO_NUMBER_FUNCS test_unary_gradient(f, rand(5)) end +vec_to_hermitian = (v) -> begin A = I - 2 * v * collect(v'); A = collect(A') * A end; +for f in (y -> norm(vec_to_hermitian(y)),) + test_println("VECTOR_TO_NUMBER_FUNCS", f) + test_unary_gradient(f, rand(5)) +end + for f in DiffTests.TERNARY_MATRIX_TO_NUMBER_FUNCS test_println("TERNARY_MATRIX_TO_NUMBER_FUNCS", f) test_ternary_gradient(f, rand(5, 5), rand(5, 5), rand(5, 5)) From e0dc75db5daaa472a89e06c97f092d1f8117089d Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Sat, 6 May 2023 10:52:26 +0200 Subject: [PATCH 4/8] use LinearAlgebra --- test/api/GradientTests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/api/GradientTests.jl b/test/api/GradientTests.jl index 5f37be32..56565374 100644 --- a/test/api/GradientTests.jl +++ b/test/api/GradientTests.jl @@ -1,6 +1,6 @@ module GradientTests -using DiffTests, ForwardDiff, ReverseDiff, Test +using DiffTests, ForwardDiff, ReverseDiff, Test, LinearAlgebra include(joinpath(dirname(@__FILE__), "../utils.jl")) From 29f94fda40df6cfada756a790b62267a5ee41f3c Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Sun, 7 May 2023 09:17:43 +0200 Subject: [PATCH 5/8] Apply suggestions from code review Co-authored-by: David Widmann --- test/api/GradientTests.jl | 10 ++++++---- test/derivatives/LinAlgTests.jl | 8 ++++++-- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/test/api/GradientTests.jl b/test/api/GradientTests.jl index 56565374..f76ea68b 100644 --- a/test/api/GradientTests.jl +++ b/test/api/GradientTests.jl @@ -187,11 +187,13 @@ for f in DiffTests.VECTOR_TO_NUMBER_FUNCS test_unary_gradient(f, rand(5)) end -vec_to_hermitian = (v) -> begin A = I - 2 * v * collect(v'); A = collect(A') * A end; -for f in (y -> norm(vec_to_hermitian(y)),) - test_println("VECTOR_TO_NUMBER_FUNCS", f) - test_unary_gradient(f, rand(5)) +# PR #227 +function norm_hermitian(v) + A = I - 2 * v * v' + return norm(A' * A) end +test_println("VECTOR_TO_NUMBER_FUNCS", norm_hermitian) +test_unary_gradient(norm_hermitian, rand(5)) for f in DiffTests.TERNARY_MATRIX_TO_NUMBER_FUNCS test_println("TERNARY_MATRIX_TO_NUMBER_FUNCS", f) diff --git a/test/derivatives/LinAlgTests.jl b/test/derivatives/LinAlgTests.jl index f3679497..a1fa508e 100644 --- a/test/derivatives/LinAlgTests.jl +++ b/test/derivatives/LinAlgTests.jl @@ -223,11 +223,15 @@ for f in ( test_arr2num(f, x, tp) end -vec_to_hermitian = (v) -> begin A = I - 2 * v * collect(v'); A = collect(A') * A end; +# PR #227 +function norm_hermitian(v) + A = I - 2 * v * v' + return norm(A' * A) +end for f in ( y -> vec(y)' * Matrix{Float64}(I, length(y), length(y)) * vec(y), - y -> norm(vec_to_hermitian(y)), + norm_hermitian, ) test_println("Array -> Number functions", f) test_arr2num(f, x, tp, ignore_tape_length=true) From 1bbfe2868391de726dbb808d08a8fcbde139a0ea Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Mon, 8 May 2023 10:59:11 +0200 Subject: [PATCH 6/8] avoid ambiguity, add tests --- src/derivatives/linalg/arithmetic.jl | 12 +++++++----- test/api/GradientTests.jl | 22 ++++++++++++++++------ 2 files changed, 23 insertions(+), 11 deletions(-) diff --git a/src/derivatives/linalg/arithmetic.jl b/src/derivatives/linalg/arithmetic.jl index be0c2a80..91f99fad 100644 --- a/src/derivatives/linalg/arithmetic.jl +++ b/src/derivatives/linalg/arithmetic.jl @@ -270,11 +270,13 @@ end # a * b function reverse_mul!(output, output_deriv, a, b, a_tmp, b_tmp) - istracked(a) && increment_deriv!(a, mul!(a_tmp, output_deriv, transpose(value(b)))) - istracked(b) && increment_deriv!(b, mul!(b_tmp, transpose(value(a)), output_deriv)) -end -function reverse_mul!(output, output_deriv::AbstractMatrix, a, b::AbstractMatrix, a_tmp::AbstractVector, b_tmp) - istracked(a) && increment_deriv!(a, mul!(reshape(a_tmp, :, 1), output_deriv, transpose(value(b)))) + if istracked(a) + if a_tmp isa AbstractVector && b isa AbstractMatrix + increment_deriv!(a, mul!(reshape(a_tmp, :, 1), output_deriv, transpose(value(b)))) + else + increment_deriv!(a, mul!(a_tmp, output_deriv, transpose(value(b)))) + end + end istracked(b) && increment_deriv!(b, mul!(b_tmp, transpose(value(a)), output_deriv)) end diff --git a/test/api/GradientTests.jl b/test/api/GradientTests.jl index f76ea68b..2d7ea1f4 100644 --- a/test/api/GradientTests.jl +++ b/test/api/GradientTests.jl @@ -188,12 +188,22 @@ for f in DiffTests.VECTOR_TO_NUMBER_FUNCS end # PR #227 -function norm_hermitian(v) - A = I - 2 * v * v' - return norm(A' * A) -end -test_println("VECTOR_TO_NUMBER_FUNCS", norm_hermitian) -test_unary_gradient(norm_hermitian, rand(5)) +norm_hermitian1(v) = (A = I - 2 * v * v'; norm(A' * A)) +norm_hermitian2(v) = (A = I - 2 * v * transpose(v); norm(transpose(A) * A)) +norm_hermitian3(v) = (A = I - 2 * v * collect(v'); norm(collect(A') * A)) +norm_hermitian4(v) = (A = I - 2 * v * v'; norm(transpose(A) * A)) +norm_hermitian5(v) = (A = I - 2 * v * transpose(v); norm(A' * A)) + +test_println("VECTOR_TO_NUMBER_FUNCS", norm_hermitian1) +test_unary_gradient(norm_hermitian1, rand(5)) +test_println("VECTOR_TO_NUMBER_FUNCS", norm_hermitian2) +test_unary_gradient(norm_hermitian2, rand(5)) +test_println("VECTOR_TO_NUMBER_FUNCS", norm_hermitian3) +test_unary_gradient(norm_hermitian3, rand(5)) +test_println("VECTOR_TO_NUMBER_FUNCS", norm_hermitian4) +test_unary_gradient(norm_hermitian4, rand(5)) +test_println("VECTOR_TO_NUMBER_FUNCS", norm_hermitian5) +test_unary_gradient(norm_hermitian5, rand(5)) for f in DiffTests.TERNARY_MATRIX_TO_NUMBER_FUNCS test_println("TERNARY_MATRIX_TO_NUMBER_FUNCS", f) From 7b744e78c6e455c45065010e3dc40be38df8b231 Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Mon, 8 May 2023 15:18:54 +0200 Subject: [PATCH 7/8] Update test/api/GradientTests.jl Co-authored-by: David Widmann --- test/api/GradientTests.jl | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/test/api/GradientTests.jl b/test/api/GradientTests.jl index 2d7ea1f4..a01fb2c9 100644 --- a/test/api/GradientTests.jl +++ b/test/api/GradientTests.jl @@ -194,16 +194,10 @@ norm_hermitian3(v) = (A = I - 2 * v * collect(v'); norm(collect(A') * A)) norm_hermitian4(v) = (A = I - 2 * v * v'; norm(transpose(A) * A)) norm_hermitian5(v) = (A = I - 2 * v * transpose(v); norm(A' * A)) -test_println("VECTOR_TO_NUMBER_FUNCS", norm_hermitian1) -test_unary_gradient(norm_hermitian1, rand(5)) -test_println("VECTOR_TO_NUMBER_FUNCS", norm_hermitian2) -test_unary_gradient(norm_hermitian2, rand(5)) -test_println("VECTOR_TO_NUMBER_FUNCS", norm_hermitian3) -test_unary_gradient(norm_hermitian3, rand(5)) -test_println("VECTOR_TO_NUMBER_FUNCS", norm_hermitian4) -test_unary_gradient(norm_hermitian4, rand(5)) -test_println("VECTOR_TO_NUMBER_FUNCS", norm_hermitian5) -test_unary_gradient(norm_hermitian5, rand(5)) +for f in (norm_hermitian1, norm_hermitian2, norm_hermitian3, norm_hermitian4, norm_hermitian5) + test_println("VECTOR_TO_NUMBER_FUNCS", f) + test_unary_gradient(f, rand(5)) +end for f in DiffTests.TERNARY_MATRIX_TO_NUMBER_FUNCS test_println("TERNARY_MATRIX_TO_NUMBER_FUNCS", f) From 90f676b9754bf7bb797dca8762ee21307bf21d34 Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Mon, 8 May 2023 16:14:47 +0200 Subject: [PATCH 8/8] fix --- src/derivatives/linalg/arithmetic.jl | 14 ++++++++++++-- test/api/GradientTests.jl | 4 +++- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/src/derivatives/linalg/arithmetic.jl b/src/derivatives/linalg/arithmetic.jl index 91f99fad..b3e91a62 100644 --- a/src/derivatives/linalg/arithmetic.jl +++ b/src/derivatives/linalg/arithmetic.jl @@ -272,6 +272,10 @@ end function reverse_mul!(output, output_deriv, a, b, a_tmp, b_tmp) if istracked(a) if a_tmp isa AbstractVector && b isa AbstractMatrix + # this branch is required for scalar-valued functions that + # involve outer-products of vectors, for such functions, the target + # a_temp is a vector, but when b is a matrix, we cannot multiply into a vector, + # so need to reshape memory to look like matrix (see PositiveFactorizations.jl) increment_deriv!(a, mul!(reshape(a_tmp, :, 1), output_deriv, transpose(value(b)))) else increment_deriv!(a, mul!(a_tmp, output_deriv, transpose(value(b)))) @@ -285,8 +289,14 @@ for (f, F) in ((:transpose, :Transpose), (:adjoint, :Adjoint)) # a * f(b) function reverse_mul!(output, output_deriv, a, b::$F, a_tmp, b_tmp) _b = ($f)(b) - istracked(a) && increment_deriv!(a, mul!(a_tmp, output_deriv, mulargvalue(b))) - istracked(_b) && increment_deriv!(_b, ($f)(mul!(b_tmp, ($f)(output_deriv), value(a)))) + if istracked(a) + if a_tmp isa AbstractVector + increment_deriv!(a, mul!(reshape(a_tmp, :, 1), output_deriv, mulargvalue(_b))) + else + increment_deriv!(a, mul!(a_tmp, output_deriv, mulargvalue(b))) + end + end + istracked(_b) && increment_deriv!(_b, ($f)(mul!(($f)(b_tmp), ($f)(output_deriv), value(a)))) end # f(a) * b function reverse_mul!(output, output_deriv, a::$F, b, a_tmp, b_tmp) diff --git a/test/api/GradientTests.jl b/test/api/GradientTests.jl index a01fb2c9..741bd049 100644 --- a/test/api/GradientTests.jl +++ b/test/api/GradientTests.jl @@ -193,8 +193,10 @@ norm_hermitian2(v) = (A = I - 2 * v * transpose(v); norm(transpose(A) * A)) norm_hermitian3(v) = (A = I - 2 * v * collect(v'); norm(collect(A') * A)) norm_hermitian4(v) = (A = I - 2 * v * v'; norm(transpose(A) * A)) norm_hermitian5(v) = (A = I - 2 * v * transpose(v); norm(A' * A)) +norm_hermitian6(v) = (A = (v'v)*I - 2 * v * v'; norm(A' * A)) -for f in (norm_hermitian1, norm_hermitian2, norm_hermitian3, norm_hermitian4, norm_hermitian5) +for f in (norm_hermitian1, norm_hermitian2, norm_hermitian3, + norm_hermitian4, norm_hermitian5, norm_hermitian6) test_println("VECTOR_TO_NUMBER_FUNCS", f) test_unary_gradient(f, rand(5)) end