Skip to content

Commit e8ede63

Browse files
authored
Add rule for triangular solves (#1264)
1 parent 2c1fb5d commit e8ede63

File tree

3 files changed

+85
-0
lines changed

3 files changed

+85
-0
lines changed

src/internal_rules.jl

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,65 @@ function EnzymeRules.reverse(config, func::Const{typeof(\)}, ::Type{RT}, cache,
447447
return (nothing,nothing)
448448
end
449449

450+
const EnzymeTriangulars = Union{
451+
UpperTriangular,
452+
LowerTriangular,
453+
UnitUpperTriangular,
454+
UnitLowerTriangular
455+
}
456+
457+
function EnzymeRules.augmented_primal(
458+
config,
459+
func::Const{typeof(ldiv!)},
460+
::Type{RT},
461+
Y::Annotation{YT},
462+
A::Annotation{AT},
463+
B::Annotation{BT}
464+
) where {RT, YT <: Array, AT <: EnzymeTriangulars, BT <: Array}
465+
cache_Y = EnzymeRules.overwritten(config)[1] ? copy(Y.val) : Y.val
466+
cache_A = EnzymeRules.overwritten(config)[2] ? copy(A.val) : A.val
467+
cache_A = compute_lu_cache(cache_A, B.val)
468+
cache_B = EnzymeRules.overwritten(config)[3] ? copy(B.val) : nothing
469+
primal = EnzymeRules.needs_primal(config) ? Y.val : nothing
470+
shadow = EnzymeRules.needs_shadow(config) ? Y.dval : nothing
471+
func.val(Y.val, A.val, B.val)
472+
return EnzymeRules.AugmentedReturn{typeof(primal), typeof(shadow), Any}(
473+
primal, shadow, (cache_Y, cache_A, cache_B))
474+
end
475+
476+
function EnzymeRules.reverse(
477+
config,
478+
func::Const{typeof(ldiv!)},
479+
::Type{RT},
480+
cache,
481+
Y::Annotation{YT},
482+
A::Annotation{AT},
483+
B::Annotation{BT}
484+
) where {YT <: Array, RT, AT <: EnzymeTriangulars, BT <: Array}
485+
if !isa(Y, Const)
486+
(cache_Yout, cache_A, cache_B) = cache
487+
for b in 1:EnzymeRules.width(config)
488+
dY = EnzymeRules.width(config) == 1 ? Y.dval : Y.dval[b]
489+
z = adjoint(cache_A) \ dY
490+
if !isa(B, Const)
491+
dB = EnzymeRules.width(config) == 1 ? B.dval : B.dval[b]
492+
dB .+= z
493+
end
494+
if !isa(A, Const)
495+
dA = EnzymeRules.width(config) == 1 ? A.dval : A.dval[b]
496+
dA.data .-= _zero_unused_elements!(AT(z * adjoint(cache_Yout)))
497+
end
498+
dY .= zero(eltype(dY))
499+
end
500+
end
501+
return (nothing, nothing, nothing)
502+
end
503+
504+
_zero_unused_elements!(A::UpperTriangular) = triu!(A.data)
505+
_zero_unused_elements!(A::LowerTriangular) = tril!(A.data)
506+
_zero_unused_elements!(A::UnitUpperTriangular) = triu!(A.data, 1)
507+
_zero_unused_elements!(A::UnitLowerTriangular) = tril!(A.data, -1)
508+
450509
@static if VERSION >= v"1.7-"
451510
# Force a rule around hvcat_fill as it is type unstable if the tuple is not of the same type (e.g., int, float, int, float)
452511
function EnzymeRules.augmented_primal(config, func::Const{typeof(Base.hvcat_fill!)}, ::Type{RT}, out::Annotation{AT}, inp::Annotation{BT}) where {RT, AT <: Array, BT <: Tuple}

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1919

2020
[compat]
2121
Aqua = "0.8"
22+
EnzymeTestUtils = "0.1.4"

test/internal_rules.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ module InternalRules
22

33
using Enzyme
44
using Enzyme.EnzymeRules
5+
using EnzymeTestUtils
56
using FiniteDifferences
67
using LinearAlgebra
78
using SparseArrays
@@ -386,5 +387,29 @@ end
386387
@test isapprox(dA, dA_sym)
387388
end
388389
end
390+
391+
@testset "Linear solve for triangular matrices" begin
392+
@testset for T in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular),
393+
TE in (Float64, ComplexF64), sizeB in ((3,), (3, 3))
394+
n = sizeB[1]
395+
M = rand(TE, n, n)
396+
B = rand(TE, sizeB...)
397+
Y = zeros(TE, sizeB...)
398+
A = T(M)
399+
@testset "test against EnzymeTestUtils through constructor" begin
400+
_A = T(A)
401+
function f!(Y, A, B, ::T) where T
402+
ldiv!(Y, T(A), B)
403+
return nothing
404+
end
405+
for TY in (Const, Duplicated, BatchDuplicated),
406+
TM in (Const, Duplicated, BatchDuplicated),
407+
TB in (Const, Duplicated, BatchDuplicated)
408+
are_activities_compatible(Const, TY, TM, TB) || continue
409+
test_reverse(f!, Const, (Y, TY), (M, TM), (B, TB), (_A, Const))
410+
end
411+
end
412+
end
413+
end
389414
end
390415
end # InternalRules

0 commit comments

Comments
 (0)