@@ -3,37 +3,47 @@ using Integrals
33if isdefined (Base, :get_extension )
44 using Zygote
55 import ChainRulesCore
6- import ChainRulesCore: NoTangent
6+ import ChainRulesCore: NoTangent, ProjectTo
77else
88 using .. Zygote
99 import .. Zygote. ChainRulesCore
10- import .. Zygote. ChainRulesCore: NoTangent
10+ import .. Zygote. ChainRulesCore: NoTangent, ProjectTo
1111end
1212ChainRulesCore. @non_differentiable Integrals. checkkwargs (kwargs... )
13+ ChainRulesCore. @non_differentiable Integrals. isinplace (f, n) # fixes #99
1314
1415function ChainRulesCore. rrule (:: typeof (Integrals. __solvebp), cache, alg, sensealg, lb, ub,
1516 p;
1617 kwargs... )
1718 out = Integrals. __solvebp_call (cache, alg, sensealg, lb, ub, p; kwargs... )
1819
20+ # the adjoint will be the integral of the input sensitivities, so it maps the
21+ # sensitivity of the output to an object of the type of the parameters
1922 function quadrature_adjoint (Δ)
20- y = typeof (Δ) <: Array{<:Number, 0} ? Δ[1 ] : Δ
23+ # https://juliadiff.org/ChainRulesCore.jl/dev/design/many_tangents.html#manytypes
24+ y = cache. nout == 1 ? Δ[1 ] : Δ # interpret the output as scalar
25+ # this will not be type-stable, but I believe it is unavoidable due to two ambiguities:
26+ # 1. Δ is the output of the algorithm, and when nout = 1 it is undefined whether the
27+ # output of the algorithm must be a scalar or a vector of length 1
28+ # 2. when nout = 1 the integrand can either be a scalar or a vector of length 1
2129 if isinplace (cache)
2230 dx = zeros (cache. nout)
2331 _f = x -> cache. f (dx, x, p)
2432 if sensealg. vjp isa Integrals. ZygoteVJP
2533 dfdp = function (dx, x, p)
26- _, back = Zygote. pullback (p) do p
27- _dx = Zygote. Buffer (x, cache. nout, size (x, 2 ))
34+ z, back = Zygote. pullback (p) do p
35+ _dx = cache. nout == 1 ?
36+ Zygote. Buffer (dx, eltype (y), size (x, ndims (x))) :
37+ Zygote. Buffer (dx, eltype (y), cache. nout, size (x, ndims (x)))
2838 cache. f (_dx, x, p)
2939 copy (_dx)
3040 end
31-
32- z = zeros ( size (x, 2 ))
33- for idx in 1 : size (x, 2 )
34- z[ 1 ] = 1
35- dx[:, idx] = back (z)[ 1 ]
36- z[ idx] = 0
41+ z . = zero ( eltype (z))
42+ for idx in 1 : size (x, ndims (x ))
43+ z isa Vector ? (z[ idx] = y) : (z[:, idx] . = y )
44+ dx[:, idx] . = back (z)[ 1 ]
45+ z isa Vector ? (z[ idx] = zero ( eltype (z))) :
46+ (z[:, idx] . = zero ( eltype (z)))
3747 end
3848 end
3949 elseif sensealg. vjp isa Integrals. ReverseDiffVJP
@@ -44,14 +54,21 @@ function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), cache, alg, senseal
4454 if sensealg. vjp isa Integrals. ZygoteVJP
4555 if cache. batch > 0
4656 dfdp = function (x, p)
47- _, back = Zygote. pullback (p -> cache. f (x, p), p)
57+ z, back = Zygote. pullback (p -> cache. f (x, p), p)
58+ # messy, there are 4 cases, some better in forward mode than reverse
59+ # 1: length(y) == 1 and length(p) == 1
60+ # 2: length(y) > 1 and length(p) == 1
61+ # 3: length(y) == 1 and length(p) > 1
62+ # 4: length(y) > 1 and length(p) > 1
4863
49- out = zeros (length (p), size (x, 2 ))
50- z = zeros (size (x, 2 ))
51- for idx in 1 : size (x, 2 )
52- z[idx] = 1
53- out[:, idx] = back (z)[1 ]
54- z[idx] = 0
64+ z .= zero (eltype (z))
65+ out = zeros (eltype (p), size (p)... , size (x, ndims (x)))
66+ for idx in 1 : size (x, ndims (x))
67+ z isa Vector ? (z[idx] = y) : (z[:, idx] .= y)
68+ out isa Vector ? (out[idx] = back (z)[1 ]) :
69+ (out[:, idx] .= back (z)[1 ])
70+ z isa Vector ? (z[idx] = zero (y)) :
71+ (z[:, idx] .= zero (eltype (y)))
5572 end
5673 out
5774 end
@@ -76,17 +93,30 @@ function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), cache, alg, senseal
7693 do_inf_transformation = Val (false ),
7794 cache. kwargs... )
7895
79- if p isa Number
80- dp = Integrals. __solvebp_call (dp_cache, alg, sensealg, lb, ub, p; kwargs... )[1 ]
81- else
82- dp = Integrals. __solvebp_call (dp_cache, alg, sensealg, lb, ub, p; kwargs... ). u
83- end
96+ project_p = ProjectTo (p)
97+ dp = project_p (Integrals. __solvebp_call (dp_cache,
98+ alg,
99+ sensealg,
100+ lb,
101+ ub,
102+ p;
103+ kwargs... ). u)
84104
85105 if lb isa Number
86- dlb = - _f (lb)
87- dub = _f (ub)
106+ dlb = cache . batch > 0 ? - _f ([lb]) : - _f (lb)
107+ dub = cache . batch > 0 ? _f ([ub]) : _f (ub)
88108 return (NoTangent (), NoTangent (), NoTangent (), NoTangent (), dlb, dub, dp)
89109 else
110+ # we need to compute 2*length(lb) integrals on the faces of the hypercube, as we
111+ # can see from writing the multidimensional integral as an iterated integral
112+ # alternatively we can use Stokes' theorem to replace the integral on the
113+ # boundary with a volume integral of the flux of the integrand
114+ # ∫∂Ω ω = ∫Ω dω, which would be better since we won't have to change the
115+ # dimensionality of the integral or the quadrature used (such as quadratures
116+ # that don't evaluate points on the boundaries) and it could be generalized to
117+ # other kinds of domains. The only question is to determine ω in terms of f and
118+ # the deformation of the surface (e.g. consider integral over an ellipse and
119+ # asking for the derivative of the result w.r.t. the semiaxes of the ellipse)
90120 return (NoTangent (), NoTangent (), NoTangent (), NoTangent (), NoTangent (),
91121 NoTangent (), dp)
92122 end
0 commit comments