diff --git a/.github/workflows/Integration.yml b/.github/workflows/Integration.yml index 2e2d5f1982..9568416a4c 100644 --- a/.github/workflows/Integration.yml +++ b/.github/workflows/Integration.yml @@ -50,6 +50,7 @@ jobs: - Distributions - DynamicExpressions - Lux + - SciML steps: - uses: actions/checkout@v5 - uses: julia-actions/setup-julia@v2 diff --git a/test/integration/SciML/Project.toml b/test/integration/SciML/Project.toml new file mode 100644 index 0000000000..7790b191a5 --- /dev/null +++ b/test/integration/SciML/Project.toml @@ -0,0 +1,27 @@ +[deps] +DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" +OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" +OrdinaryDiffEqTsit5 = "b1df2697-797e-41e3-8120-5422d3b24e4a" +SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + +[sources] +Enzyme = {path = "../../.."} +EnzymeCore = {path = "../../../lib/EnzymeCore"} + +[compat] +DiffEqBase = "6.190" +ForwardDiff = "0.10.36, 1" +LinearSolve = "3.12" +OrdinaryDiffEq = "6.89" +OrdinaryDiffEqTsit5 = "1.1" +SciMLSensitivity = "7.69" +StaticArrays = "1.9" +Zygote = "0.7.10" diff --git a/test/integration/SciML/runtests.jl b/test/integration/SciML/runtests.jl new file mode 100644 index 0000000000..7c655e40f1 --- /dev/null +++ b/test/integration/SciML/runtests.jl @@ -0,0 +1,100 @@ +using Enzyme, OrdinaryDiffEqTsit5, StaticArrays, DiffEqBase, ForwardDiff, Test +using OrdinaryDiffEq, SciMLSensitivity, Zygote +using LinearSolve, LinearAlgebra + +@testset "Direct Differentiation of Explicit ODE Solve" begin + function lorenz!(du, u, p, t) + du[1] = 10.0(u[2] - u[1]) + du[2] = u[1] * (28.0 - u[3]) - u[2] + du[3] = u[1] * u[2] - (8 / 3) * u[3] + end + + _saveat = SA[0.0,0.25,0.5,0.75,1.0,1.25,1.5,1.75,2.0,2.25,2.5,2.75,3.0] + + function f_dt(y::Array{Float64}, u0::Array{Float64}) + tspan = (0.0, 3.0) + prob = ODEProblem{true, SciMLBase.FullSpecialize}(lorenz!, u0, tspan) + sol = DiffEqBase.solve(prob, Tsit5(), saveat = _saveat, sensealg = DiffEqBase.SensitivityADPassThrough(), abstol=1e-12, reltol=1e-12) + y .= sol[1,:] + return nothing + end; + + function f_dt(u0) + tspan = (0.0, 3.0) + prob = ODEProblem{true, SciMLBase.FullSpecialize}(lorenz!, u0, tspan) + sol = DiffEqBase.solve(prob, Tsit5(), saveat = _saveat, sensealg = DiffEqBase.SensitivityADPassThrough(), abstol=1e-12, reltol=1e-12) + sol[1,:] + end; + + u0 = [1.0; 0.0; 0.0] + fdj = ForwardDiff.jacobian(f_dt, u0) + + ezj = stack(map(1:3) do i + d_u0 = zeros(3) + dy = zeros(13) + y = zeros(13) + d_u0[i] = 1.0 + Enzyme.autodiff(Forward, f_dt, Duplicated(y, dy), Duplicated(u0, d_u0)); + dy + end) + + @test ezj ≈ fdj + + function f_dt2(u0) + tspan = (0.0, 3.0) + prob = ODEProblem{true, SciMLBase.FullSpecialize}(lorenz!, u0, tspan) + sol = DiffEqBase.solve(prob, Tsit5(), dt=0.1, saveat = _saveat, sensealg = DiffEqBase.SensitivityADPassThrough(), abstol=1e-12, reltol=1e-12) + sum(sol[1,:]) + end + + fdg = ForwardDiff.gradient(f_dt2, u0) + d_u0 = zeros(3) + Enzyme.autodiff(Reverse, f_dt2, Active, Duplicated(u0, d_u0)); + + @test d_u0 ≈ fdg +end + +odef(du, u, p, t) = du .= u .* p +prob = ODEProblem(odef, [2.0], (0.0, 1.0), [3.0]) +struct senseloss0{T} + sense::T +end +function (f::senseloss0)(u0p) + prob = ODEProblem{true}(odef, u0p[1:1], (0.0, 1.0), u0p[2:2]) + sum(solve(prob, Tsit5(), abstol = 1e-12, reltol = 1e-12, saveat = 0.1)) +end + +@testset "SciMLSensitivity Adjoint Interface" begin + u0p = [2.0, 3.0] + du0p = zeros(2) + @test senseloss0(InterpolatingAdjoint())(u0p) isa Number + dup = Zygote.gradient(senseloss0(InterpolatingAdjoint()), u0p)[1] + Enzyme.autodiff(Reverse, senseloss0(InterpolatingAdjoint()), Active, Duplicated(u0p, du0p)) + @test du0p ≈ dup +end + + @testset "LinearSolve Adjoints" begin + n = 4 + A = rand(n, n); + dA = zeros(n, n); + b1 = rand(n); + db1 = zeros(n); + + function f(A, b1; alg = LUFactorization()) + prob = LinearProblem(A, b1) + + sol1 = solve(prob, alg) + + s1 = sol1.u + norm(s1) + end + + f(A, b1) # Uses BLAS + + Enzyme.autodiff(Reverse, f, Duplicated(copy(A), dA), Duplicated(copy(b1), db1)) + dA2 = ForwardDiff.gradient(x -> f(x, eltype(x).(b1)), copy(A)) + db12 = ForwardDiff.gradient(x -> f(eltype(x).(A), x), copy(b1)) + + @test dA ≈ dA2 + @test db1 ≈ db12 +end