-
Notifications
You must be signed in to change notification settings - Fork 82
Add SciML integration tests #2593
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
Your PR requires formatting changes to meet the project's style guidelines. Click here to view the suggested changes.diff --git a/test/integration/SciML/runtests.jl b/test/integration/SciML/runtests.jl
index 7c655e40..631ab6a9 100644
--- a/test/integration/SciML/runtests.jl
+++ b/test/integration/SciML/runtests.jl
@@ -9,47 +9,49 @@ using LinearSolve, LinearAlgebra
du[3] = u[1] * u[2] - (8 / 3) * u[3]
end
- _saveat = SA[0.0,0.25,0.5,0.75,1.0,1.25,1.5,1.75,2.0,2.25,2.5,2.75,3.0]
+ _saveat = SA[0.0, 0.25, 0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0, 2.25, 2.5, 2.75, 3.0]
function f_dt(y::Array{Float64}, u0::Array{Float64})
tspan = (0.0, 3.0)
prob = ODEProblem{true, SciMLBase.FullSpecialize}(lorenz!, u0, tspan)
- sol = DiffEqBase.solve(prob, Tsit5(), saveat = _saveat, sensealg = DiffEqBase.SensitivityADPassThrough(), abstol=1e-12, reltol=1e-12)
- y .= sol[1,:]
+ sol = DiffEqBase.solve(prob, Tsit5(), saveat = _saveat, sensealg = DiffEqBase.SensitivityADPassThrough(), abstol = 1.0e-12, reltol = 1.0e-12)
+ y .= sol[1, :]
return nothing
- end;
+ end
function f_dt(u0)
tspan = (0.0, 3.0)
prob = ODEProblem{true, SciMLBase.FullSpecialize}(lorenz!, u0, tspan)
- sol = DiffEqBase.solve(prob, Tsit5(), saveat = _saveat, sensealg = DiffEqBase.SensitivityADPassThrough(), abstol=1e-12, reltol=1e-12)
- sol[1,:]
- end;
+ sol = DiffEqBase.solve(prob, Tsit5(), saveat = _saveat, sensealg = DiffEqBase.SensitivityADPassThrough(), abstol = 1.0e-12, reltol = 1.0e-12)
+ sol[1, :]
+ end
u0 = [1.0; 0.0; 0.0]
fdj = ForwardDiff.jacobian(f_dt, u0)
- ezj = stack(map(1:3) do i
- d_u0 = zeros(3)
- dy = zeros(13)
- y = zeros(13)
- d_u0[i] = 1.0
- Enzyme.autodiff(Forward, f_dt, Duplicated(y, dy), Duplicated(u0, d_u0));
- dy
- end)
+ ezj = stack(
+ map(1:3) do i
+ d_u0 = zeros(3)
+ dy = zeros(13)
+ y = zeros(13)
+ d_u0[i] = 1.0
+ Enzyme.autodiff(Forward, f_dt, Duplicated(y, dy), Duplicated(u0, d_u0))
+ dy
+ end
+ )
@test ezj ≈ fdj
function f_dt2(u0)
tspan = (0.0, 3.0)
prob = ODEProblem{true, SciMLBase.FullSpecialize}(lorenz!, u0, tspan)
- sol = DiffEqBase.solve(prob, Tsit5(), dt=0.1, saveat = _saveat, sensealg = DiffEqBase.SensitivityADPassThrough(), abstol=1e-12, reltol=1e-12)
- sum(sol[1,:])
+ sol = DiffEqBase.solve(prob, Tsit5(), dt = 0.1, saveat = _saveat, sensealg = DiffEqBase.SensitivityADPassThrough(), abstol = 1.0e-12, reltol = 1.0e-12)
+ sum(sol[1, :])
end
fdg = ForwardDiff.gradient(f_dt2, u0)
d_u0 = zeros(3)
- Enzyme.autodiff(Reverse, f_dt2, Active, Duplicated(u0, d_u0));
+ Enzyme.autodiff(Reverse, f_dt2, Active, Duplicated(u0, d_u0))
@test d_u0 ≈ fdg
end
@@ -61,7 +63,7 @@ struct senseloss0{T}
end
function (f::senseloss0)(u0p)
prob = ODEProblem{true}(odef, u0p[1:1], (0.0, 1.0), u0p[2:2])
- sum(solve(prob, Tsit5(), abstol = 1e-12, reltol = 1e-12, saveat = 0.1))
+ return sum(solve(prob, Tsit5(), abstol = 1.0e-12, reltol = 1.0e-12, saveat = 0.1))
end
@testset "SciMLSensitivity Adjoint Interface" begin
@@ -73,12 +75,12 @@ end
@test du0p ≈ dup
end
- @testset "LinearSolve Adjoints" begin
+@testset "LinearSolve Adjoints" begin
n = 4
- A = rand(n, n);
- dA = zeros(n, n);
- b1 = rand(n);
- db1 = zeros(n);
+ A = rand(n, n)
+ dA = zeros(n, n)
+ b1 = rand(n)
+ db1 = zeros(n)
function f(A, b1; alg = LUFactorization())
prob = LinearProblem(A, b1) |
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## main #2593 +/- ##
=======================================
Coverage 75.10% 75.10%
=======================================
Files 56 56
Lines 17724 17724
=======================================
Hits 13312 13312
Misses 4412 4412 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
|
|
|
BTW, it looks like this test just regressed in the latest release 😅 |
|
https://github.com/EnzymeAD/Enzyme.jl/actions/runs/17889611301/job/50873551207?pr=2593#step:9:11 (only in v1.11)
That's why we need these tests! |
|
Yes that's the expected failure from the latest release |
89d5d94 to
f9b339e
Compare
|
the above error is now fixed, however now it hits: which I really need an MWE for to figure out/fix |
f9b339e to
48b95aa
Compare
|
approve to see? locally it's just the testset. |
|
yeah I don't know why that case doesn't work with Enzyme inside of a testset, but that is the case 🤷 |
test/integration/SciML/runtests.jl
Outdated
| #@testset "SciMLSensitivity Adjoint Interface" begin | ||
| Enzyme.API.typeWarning!(false) | ||
|
|
||
| odef(du, u, p, t) = du .= u .* p |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what if instead you put the function defns outside of the testset (but kept the testset)
|
Seems to work like this locally. Not sure why Enzyme is sensitive here while Zygote isn't but I think it's ignorable. |
|
probably closure causing sadness |
|
done |
No description provided.