Skip to content

Commit 823046c

Browse files
Merge pull request #175 from lxvm/autodiff
Fix AD for parameters
2 parents bb6a2a9 + 2d3ee8e commit 823046c

File tree

5 files changed

+107
-92
lines changed

5 files changed

+107
-92
lines changed

ext/IntegralsForwardDiffExt.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ function Integrals.__solvebp(cache, alg, sensealg, lb, ub,
2929
dfdp = function (out, x, p)
3030
dualp = reinterpret(ForwardDiff.Dual{T, V, P}, p)
3131
if cache.batch > 0
32-
dx = similar(dualp, cache.nout, size(x, 2))
32+
dx = cache.nout == 1 ? similar(dualp, size(x, ndims(x))) :
33+
similar(dualp, cache.nout, size(x, ndims(x)))
3334
else
3435
dx = similar(dualp, cache.nout)
3536
end
@@ -49,7 +50,7 @@ function Integrals.__solvebp(cache, alg, sensealg, lb, ub,
4950
dualp = reinterpret(ForwardDiff.Dual{T, V, P}, p)
5051
ys = cache.f(x, dualp)
5152
if cache.batch > 0
52-
out = similar(p, V, nout, size(x, 2))
53+
out = similar(p, V, nout, size(x, ndims(x)))
5354
else
5455
out = similar(p, V, nout)
5556
end

ext/IntegralsZygoteExt.jl

Lines changed: 55 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3,37 +3,47 @@ using Integrals
33
if isdefined(Base, :get_extension)
44
using Zygote
55
import ChainRulesCore
6-
import ChainRulesCore: NoTangent
6+
import ChainRulesCore: NoTangent, ProjectTo
77
else
88
using ..Zygote
99
import ..Zygote.ChainRulesCore
10-
import ..Zygote.ChainRulesCore: NoTangent
10+
import ..Zygote.ChainRulesCore: NoTangent, ProjectTo
1111
end
1212
ChainRulesCore.@non_differentiable Integrals.checkkwargs(kwargs...)
13+
ChainRulesCore.@non_differentiable Integrals.isinplace(f, n) # fixes #99
1314

1415
function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), cache, alg, sensealg, lb, ub,
1516
p;
1617
kwargs...)
1718
out = Integrals.__solvebp_call(cache, alg, sensealg, lb, ub, p; kwargs...)
1819

20+
# the adjoint will be the integral of the input sensitivities, so it maps the
21+
# sensitivity of the output to an object of the type of the parameters
1922
function quadrature_adjoint(Δ)
20-
y = typeof(Δ) <: Array{<:Number, 0} ? Δ[1] : Δ
23+
# https://juliadiff.org/ChainRulesCore.jl/dev/design/many_tangents.html#manytypes
24+
y = cache.nout == 1 ? Δ[1] : Δ # interpret the output as scalar
25+
# this will not be type-stable, but I believe it is unavoidable due to two ambiguities:
26+
# 1. Δ is the output of the algorithm, and when nout = 1 it is undefined whether the
27+
# output of the algorithm must be a scalar or a vector of length 1
28+
# 2. when nout = 1 the integrand can either be a scalar or a vector of length 1
2129
if isinplace(cache)
2230
dx = zeros(cache.nout)
2331
_f = x -> cache.f(dx, x, p)
2432
if sensealg.vjp isa Integrals.ZygoteVJP
2533
dfdp = function (dx, x, p)
26-
_, back = Zygote.pullback(p) do p
27-
_dx = Zygote.Buffer(x, cache.nout, size(x, 2))
34+
z, back = Zygote.pullback(p) do p
35+
_dx = cache.nout == 1 ?
36+
Zygote.Buffer(dx, eltype(y), size(x, ndims(x))) :
37+
Zygote.Buffer(dx, eltype(y), cache.nout, size(x, ndims(x)))
2838
cache.f(_dx, x, p)
2939
copy(_dx)
3040
end
31-
32-
z = zeros(size(x, 2))
33-
for idx in 1:size(x, 2)
34-
z[1] = 1
35-
dx[:, idx] = back(z)[1]
36-
z[idx] = 0
41+
z .= zero(eltype(z))
42+
for idx in 1:size(x, ndims(x))
43+
z isa Vector ? (z[idx] = y) : (z[:, idx] .= y)
44+
dx[:, idx] .= back(z)[1]
45+
z isa Vector ? (z[idx] = zero(eltype(z))) :
46+
(z[:, idx] .= zero(eltype(z)))
3747
end
3848
end
3949
elseif sensealg.vjp isa Integrals.ReverseDiffVJP
@@ -44,14 +54,21 @@ function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), cache, alg, senseal
4454
if sensealg.vjp isa Integrals.ZygoteVJP
4555
if cache.batch > 0
4656
dfdp = function (x, p)
47-
_, back = Zygote.pullback(p -> cache.f(x, p), p)
57+
z, back = Zygote.pullback(p -> cache.f(x, p), p)
58+
# messy, there are 4 cases, some better in forward mode than reverse
59+
# 1: length(y) == 1 and length(p) == 1
60+
# 2: length(y) > 1 and length(p) == 1
61+
# 3: length(y) == 1 and length(p) > 1
62+
# 4: length(y) > 1 and length(p) > 1
4863

49-
out = zeros(length(p), size(x, 2))
50-
z = zeros(size(x, 2))
51-
for idx in 1:size(x, 2)
52-
z[idx] = 1
53-
out[:, idx] = back(z)[1]
54-
z[idx] = 0
64+
z .= zero(eltype(z))
65+
out = zeros(eltype(p), size(p)..., size(x, ndims(x)))
66+
for idx in 1:size(x, ndims(x))
67+
z isa Vector ? (z[idx] = y) : (z[:, idx] .= y)
68+
out isa Vector ? (out[idx] = back(z)[1]) :
69+
(out[:, idx] .= back(z)[1])
70+
z isa Vector ? (z[idx] = zero(y)) :
71+
(z[:, idx] .= zero(eltype(y)))
5572
end
5673
out
5774
end
@@ -76,17 +93,30 @@ function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), cache, alg, senseal
7693
do_inf_transformation = Val(false),
7794
cache.kwargs...)
7895

79-
if p isa Number
80-
dp = Integrals.__solvebp_call(dp_cache, alg, sensealg, lb, ub, p; kwargs...)[1]
81-
else
82-
dp = Integrals.__solvebp_call(dp_cache, alg, sensealg, lb, ub, p; kwargs...).u
83-
end
96+
project_p = ProjectTo(p)
97+
dp = project_p(Integrals.__solvebp_call(dp_cache,
98+
alg,
99+
sensealg,
100+
lb,
101+
ub,
102+
p;
103+
kwargs...).u)
84104

85105
if lb isa Number
86-
dlb = -_f(lb)
87-
dub = _f(ub)
106+
dlb = cache.batch > 0 ? -_f([lb]) : -_f(lb)
107+
dub = cache.batch > 0 ? _f([ub]) : _f(ub)
88108
return (NoTangent(), NoTangent(), NoTangent(), NoTangent(), dlb, dub, dp)
89109
else
110+
# we need to compute 2*length(lb) integrals on the faces of the hypercube, as we
111+
# can see from writing the multidimensional integral as an iterated integral
112+
# alternatively we can use Stokes' theorem to replace the integral on the
113+
# boundary with a volume integral of the flux of the integrand
114+
# ∫∂Ω ω = ∫Ω dω, which would be better since we won't have to change the
115+
# dimensionality of the integral or the quadrature used (such as quadratures
116+
# that don't evaluate points on the boundaries) and it could be generalized to
117+
# other kinds of domains. The only question is to determine ω in terms of f and
118+
# the deformation of the surface (e.g. consider integral over an ellipse and
119+
# asking for the derivative of the result w.r.t. the semiaxes of the ellipse)
90120
return (NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(),
91121
NoTangent(), dp)
92122
end

lib/IntegralsCubature/src/IntegralsCubature.jl

Lines changed: 16 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,10 @@ function Integrals.__solvebp_call(prob::IntegralProblem,
5454
maxiters = typemax(Int))
5555
nout = prob.nout
5656
if nout == 1
57+
# the output of prob.f could be either scalar or a vector of length 1, however
58+
# the behavior of the output of the integration routine is undefined (could differ
59+
# across algorithms)
60+
# Cubature will output a real number in when called without nout/fdim
5761
if prob.batch == 0
5862
if isinplace(prob)
5963
dx = zeros(eltype(lb), prob.nout)
@@ -63,74 +67,52 @@ function Integrals.__solvebp_call(prob::IntegralProblem,
6367
end
6468
if lb isa Number
6569
if alg isa CubatureJLh
66-
_val, err = Cubature.hquadrature(f, lb, ub;
70+
val, err = Cubature.hquadrature(f, lb, ub;
6771
reltol = reltol, abstol = abstol,
6872
maxevals = maxiters)
6973
else
70-
_val, err = Cubature.pquadrature(f, lb, ub;
74+
val, err = Cubature.pquadrature(f, lb, ub;
7175
reltol = reltol, abstol = abstol,
7276
maxevals = maxiters)
7377
end
74-
val = prob.f(lb, p) isa Number ? _val : [_val]
7578
else
7679
if alg isa CubatureJLh
77-
_val, err = Cubature.hcubature(f, lb, ub;
80+
val, err = Cubature.hcubature(f, lb, ub;
7881
reltol = reltol, abstol = abstol,
7982
maxevals = maxiters)
8083
else
81-
_val, err = Cubature.pcubature(f, lb, ub;
84+
val, err = Cubature.pcubature(f, lb, ub;
8285
reltol = reltol, abstol = abstol,
8386
maxevals = maxiters)
8487
end
85-
86-
if isinplace(prob) || !isa(prob.f(lb, p), Number)
87-
val = [_val]
88-
else
89-
val = _val
90-
end
9188
end
9289
else
9390
if isinplace(prob)
94-
f = (x, dx) -> prob.f(dx', x, p)
95-
elseif lb isa Number
96-
if prob.f([lb ub], p) isa Vector
97-
f = (x, dx) -> (dx .= prob.f(x', p))
98-
else
99-
f = function (x, dx)
100-
dx[:] = prob.f(x', p)
101-
end
102-
end
91+
f = (x, dx) -> prob.f(dx, x, p)
10392
else
104-
if prob.f([lb ub], p) isa Vector
105-
f = (x, dx) -> (dx .= prob.f(x, p))
106-
else
107-
f = function (x, dx)
108-
dx .= prob.f(x, p)[:]
109-
end
110-
end
93+
f = (x, dx) -> (dx .= prob.f(x, p))
11194
end
11295
if lb isa Number
11396
if alg isa CubatureJLh
114-
_val, err = Cubature.hquadrature_v(f, lb, ub;
97+
val, err = Cubature.hquadrature_v(f, lb, ub;
11598
reltol = reltol, abstol = abstol,
11699
maxevals = maxiters)
117100
else
118-
_val, err = Cubature.pquadrature_v(f, lb, ub;
101+
val, err = Cubature.pquadrature_v(f, lb, ub;
119102
reltol = reltol, abstol = abstol,
120103
maxevals = maxiters)
121104
end
122105
else
123106
if alg isa CubatureJLh
124-
_val, err = Cubature.hcubature_v(f, lb, ub;
107+
val, err = Cubature.hcubature_v(f, lb, ub;
125108
reltol = reltol, abstol = abstol,
126109
maxevals = maxiters)
127110
else
128-
_val, err = Cubature.pcubature_v(f, lb, ub;
111+
val, err = Cubature.pcubature_v(f, lb, ub;
129112
reltol = reltol, abstol = abstol,
130113
maxevals = maxiters)
131114
end
132115
end
133-
val = _val isa Number ? [_val] : _val
134116
end
135117
else
136118
if prob.batch == 0
@@ -166,13 +148,9 @@ function Integrals.__solvebp_call(prob::IntegralProblem,
166148
end
167149
else
168150
if isinplace(prob)
169-
f = (x, dx) -> prob.f(dx, x, p)
151+
f = (x, dx) -> (prob.f(dx, x, p); dx)
170152
else
171-
if lb isa Number
172-
f = (x, dx) -> (dx .= prob.f(x', p))
173-
else
174-
f = (x, dx) -> (dx .= prob.f(x, p))
175-
end
153+
f = (x, dx) -> (dx .= prob.f(x, p))
176154
end
177155

178156
if lb isa Number

test/derivative_tests.jl

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using Integrals, Zygote, FiniteDiff, ForwardDiff, SciMLSensitivity
1+
using Integrals, Zygote, FiniteDiff, ForwardDiff#, SciMLSensitivity
22
using IntegralsCuba, IntegralsCubature
33
using Test
44

@@ -117,7 +117,7 @@ dp4 = ForwardDiff.gradient(p -> testf(lb, ub, p), p)
117117
@test dp1 dp4
118118

119119
### Batch Single dim
120-
f(x, p) = x * p[1] .+ p[2] * p[3]
120+
f(x, p) = x * p[1] .+ p[2] * p[3] # scalar integrand
121121

122122
lb = 1.0
123123
ub = 3.0
@@ -130,14 +130,14 @@ function testf3(lb, ub, p; f = f)
130130
end
131131

132132
dp1 = ForwardDiff.gradient(p -> testf3(lb, ub, p), p)
133-
# dp2 = Zygote.gradient(p->testf3(lb,ub,p),p)[1] # TODO fix: LoadError: DimensionMismatch("variable with size(x) == (1, 15) cannot have a gradient with size(dx) == (15,)")
133+
dp2 = Zygote.gradient(p -> testf3(lb, ub, p), p)[1] # TODO fix: LoadError: DimensionMismatch("variable with size(x) == (1, 15) cannot have a gradient with size(dx) == (15,)")
134134
dp3 = FiniteDiff.finite_difference_gradient(p -> testf3(lb, ub, p), p)
135135

136136
@test dp1 dp3 #passes
137-
@test_broken dp2 dp3 #passes
137+
@test dp2 dp3 #passes
138138

139139
### Batch single dim, nout
140-
f(x, p) = (x * p[1] .+ p[2] * p[3]) .* [1; 2]
140+
f(x, p) = (x' * p[1] .+ p[2] * p[3]) .* [1; 2]
141141

142142
lb = 1.0
143143
ub = 3.0
@@ -150,11 +150,11 @@ function testf3(lb, ub, p; f = f)
150150
end
151151

152152
dp1 = ForwardDiff.gradient(p -> testf3(lb, ub, p), p)
153-
# dp2 = Zygote.gradient(p->testf3(lb,ub,p),p)[1]
153+
dp2 = Zygote.gradient(p -> testf3(lb, ub, p), p)[1]
154154
dp3 = FiniteDiff.finite_difference_gradient(p -> testf3(lb, ub, p), p)
155155

156156
@test dp1 dp3 #passes
157-
# @test dp2 ≈ dp3 #passes
157+
@test dp2 dp3 #passes
158158

159159
### Batch multi dim
160160
f(x, p) = x[1, :] * p[1] .+ p[2] * p[3]
@@ -190,15 +190,15 @@ function testf3(lb, ub, p; f = f)
190190
end
191191

192192
dp1 = ForwardDiff.gradient(p -> testf3(lb, ub, p), p)
193-
# dp2 = Zygote.gradient(p->testf3(lb,ub,p),p)[1]
193+
dp2 = Zygote.gradient(p -> testf3(lb, ub, p), p)[1]
194194
dp3 = FiniteDiff.finite_difference_gradient(p -> testf3(lb, ub, p), p)
195195

196196
@test dp1 dp3
197-
# @test dp2 ≈ dp3
197+
@test dp2 dp3
198198

199-
## iip Batch mulit dim
199+
## iip Batch multi dim
200200
function g(dx, x, p)
201-
dx .= sum(x * p[1] .+ p[2] * p[3], dims = 1)
201+
dx .= dropdims(sum(x * p[1] .+ p[2] * p[3], dims = 1), dims = 1)
202202
end
203203

204204
lb = [1.0, 1.0]
@@ -236,8 +236,8 @@ function testf3(lb, ub, p; f = g)
236236
end
237237

238238
dp1 = ForwardDiff.gradient(p -> testf3(lb, ub, p), p)
239-
# dp2 = Zygote.gradient(p->testf3(lb,ub,p),p)[1]
239+
dp2 = Zygote.gradient(p -> testf3(lb, ub, p), p)[1]
240240
dp3 = FiniteDiff.finite_difference_gradient(p -> testf3(lb, ub, p), p)
241241

242242
@test dp1 dp3
243-
# @test dp2 ≈ dp3
243+
@test dp2 dp3

0 commit comments

Comments
 (0)