@@ -447,6 +447,65 @@ function EnzymeRules.reverse(config, func::Const{typeof(\)}, ::Type{RT}, cache,
447447 return (nothing ,nothing )
448448end
449449
450+ const EnzymeTriangulars = Union{
451+ UpperTriangular,
452+ LowerTriangular,
453+ UnitUpperTriangular,
454+ UnitLowerTriangular
455+ }
456+
457+ function EnzymeRules. augmented_primal (
458+ config,
459+ func:: Const{typeof(ldiv!)} ,
460+ :: Type{RT} ,
461+ Y:: Annotation{YT} ,
462+ A:: Annotation{AT} ,
463+ B:: Annotation{BT}
464+ ) where {RT, YT <: Array , AT <: EnzymeTriangulars , BT <: Array }
465+ cache_Y = EnzymeRules. overwritten (config)[1 ] ? copy (Y. val) : Y. val
466+ cache_A = EnzymeRules. overwritten (config)[2 ] ? copy (A. val) : A. val
467+ cache_A = compute_lu_cache (cache_A, B. val)
468+ cache_B = EnzymeRules. overwritten (config)[3 ] ? copy (B. val) : nothing
469+ primal = EnzymeRules. needs_primal (config) ? Y. val : nothing
470+ shadow = EnzymeRules. needs_shadow (config) ? Y. dval : nothing
471+ func. val (Y. val, A. val, B. val)
472+ return EnzymeRules. AugmentedReturn {typeof(primal), typeof(shadow), Any} (
473+ primal, shadow, (cache_Y, cache_A, cache_B))
474+ end
475+
476+ function EnzymeRules. reverse (
477+ config,
478+ func:: Const{typeof(ldiv!)} ,
479+ :: Type{RT} ,
480+ cache,
481+ Y:: Annotation{YT} ,
482+ A:: Annotation{AT} ,
483+ B:: Annotation{BT}
484+ ) where {YT <: Array , RT, AT <: EnzymeTriangulars , BT <: Array }
485+ if ! isa (Y, Const)
486+ (cache_Yout, cache_A, cache_B) = cache
487+ for b in 1 : EnzymeRules. width (config)
488+ dY = EnzymeRules. width (config) == 1 ? Y. dval : Y. dval[b]
489+ z = adjoint (cache_A) \ dY
490+ if ! isa (B, Const)
491+ dB = EnzymeRules. width (config) == 1 ? B. dval : B. dval[b]
492+ dB .+ = z
493+ end
494+ if ! isa (A, Const)
495+ dA = EnzymeRules. width (config) == 1 ? A. dval : A. dval[b]
496+ dA. data .- = _zero_unused_elements! (AT (z * adjoint (cache_Yout)))
497+ end
498+ dY .= zero (eltype (dY))
499+ end
500+ end
501+ return (nothing , nothing , nothing )
502+ end
503+
504+ _zero_unused_elements! (A:: UpperTriangular ) = triu! (A. data)
505+ _zero_unused_elements! (A:: LowerTriangular ) = tril! (A. data)
506+ _zero_unused_elements! (A:: UnitUpperTriangular ) = triu! (A. data, 1 )
507+ _zero_unused_elements! (A:: UnitLowerTriangular ) = tril! (A. data, - 1 )
508+
450509@static if VERSION >= v " 1.7-"
451510# Force a rule around hvcat_fill as it is type unstable if the tuple is not of the same type (e.g., int, float, int, float)
452511function EnzymeRules. augmented_primal (config, func:: Const{typeof(Base.hvcat_fill!)} , :: Type{RT} , out:: Annotation{AT} , inp:: Annotation{BT} ) where {RT, AT <: Array , BT <: Tuple }
0 commit comments