SciML/DiffEqBase.jl issue #529 (closed)
We are currently experimenting with time-dependent parameters, but the gradients often seem to come out wrong. Here is an artificially simple example, kept minimal for clarity:
```julia
using DiffEqSensitivity, OrdinaryDiffEq, Zygote

# Piecewise-constant parameter: returns values[i] for the first breakpoint tᵢ with t <= tᵢ,
# and values[end] once t is past the last breakpoint
function get_param(breakpoints, values, t)
    for (i, tᵢ) in enumerate(breakpoints)
        if t <= tᵢ
            return values[i]
        end
    end
    return values[end]
end

# In-place RHS; the coefficient `a` jumps at t = 1, 2, 3 and takes its four values from p
function fiip(du, u, p, t)
    a = get_param([1., 2., 3.], p[1:4], t)
    du[1] = dx = a * u[1] - u[1] * u[2]
    du[2] = dy = -a * u[2] + u[1] * u[2]
end

p = [1., 1., 1., 1.]; u0 = [1.0; 1.0]
prob = ODEProblem(fiip, u0, (0.0, 4.0), p);

# Gradient of sum(solution) with respect to p, using three different sensitivity algorithms
Zygote.gradient(p->sum(concrete_solve(prob, Tsit5(), u0, p, sensealg = ForwardDiffSensitivity(), saveat = 0.1)), p)
Zygote.gradient(p->sum(concrete_solve(prob, Tsit5(), u0, p, sensealg = ForwardSensitivity(), saveat = 0.1)), p)
Zygote.gradient(p->sum(concrete_solve(prob, Tsit5(), u0, p, saveat = 0.1)), p)
```

which outputs
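with `ForwardDiffSensitivity`:

```
([29.75558216432594, 10.206643764088701, 53.37700890093469, 3.5509327396481574],)
```

with `ForwardSensitivity`:

```
([33.975936715110464, 39.48130754134167, 18.942236116902354, 4.377628387002488],)
```

and with the default `sensealg`:

```
([0.0, 0.0, 0.0, 96.77707801785176],)
```

The three methods disagree with one another, and the default one returns a zero gradient for the first three parameters.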
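One way to tell which (if any) of these gradients is right is to compare them against a finite-difference reference. The sketch below is not part of the original report. It uses the same loss as the Zygote calls (the sum of the solution saved every 0.1) and adds two choices of our own: tight tolerances, and `tstops = [1.0, 2.0, 3.0]` so that the integrator lands exactly on each jump in `a` instead of stepping across it. The helper names `loss` and `fd_gradient` are hypothetical.

```julia
using OrdinaryDiffEq

# Same piecewise-constant parameter lookup and RHS as in the report
function get_param(breakpoints, values, t)
    for (i, tᵢ) in enumerate(breakpoints)
        t <= tᵢ && return values[i]
    end
    return values[end]
end

function fiip(du, u, p, t)
    a = get_param([1., 2., 3.], p[1:4], t)
    du[1] = a * u[1] - u[1] * u[2]
    du[2] = -a * u[2] + u[1] * u[2]
end

u0 = [1.0, 1.0]
p = [1.0, 1.0, 1.0, 1.0]

# Hypothetical loss mirroring the Zygote calls: sum of every saved state.
# Assumption: tstops at the breakpoints plus tight tolerances keep the reference
# from being polluted by steps that straddle a jump in `a`.
function loss(p)
    prob = ODEProblem(fiip, u0, (0.0, 4.0), p)
    sol = solve(prob, Tsit5(); saveat = 0.1, tstops = [1.0, 2.0, 3.0],
                abstol = 1e-10, reltol = 1e-10)
    return sum(Array(sol))
end

# Central finite differences, perturbing one parameter at a time
function fd_gradient(f, p; h = 1e-4)
    g = similar(p)
    for i in eachindex(p)
        pp = copy(p); pp[i] += h
        pm = copy(p); pm[i] -= h
        g[i] = (f(pp) - f(pm)) / (2h)
    end
    return g
end

g_ref = fd_gradient(loss, p)
```

Because `loss` depends smoothly on `p` (the jumps are in time, not in the parameters), the central difference should approximate the true gradient to a few digits, which is enough to see which of the three results above is off.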