Skip to content

Commit 64091bc

Browse files
committed
Introduce DifferentiationInterface
1 parent d3c68c2 commit 64091bc

7 files changed

Lines changed: 43 additions & 27 deletions

File tree

Manifest.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
julia_version = "1.11.3"
44
manifest_format = "2.0"
5-
project_hash = "7591f8aefbf83678f96ece6aa98d0232f3c083ff"
5+
project_hash = "f1f6d1cb011ce63d93f72a0bbd984e2839b32e42"
66

77
[[deps.ADTypes]]
88
git-tree-sha1 = "fb97701c117c8162e84dfcf80215caa904aef44f"
@@ -2083,7 +2083,7 @@ weakdeps = ["Distributed"]
20832083
DistributedExt = "Distributed"
20842084

20852085
[[deps.Ribasim]]
2086-
deps = ["ADTypes", "Accessors", "Arrow", "BasicModelInterface", "CodecZstd", "Configurations", "DBInterface", "DataInterpolations", "DataStructures", "Dates", "DiffEqBase", "DiffEqCallbacks", "EnumX", "FiniteDiff", "Graphs", "HiGHS", "IterTools", "JuMP", "Legolas", "LineSearches", "LinearAlgebra", "LinearSolve", "Logging", "LoggingExtras", "MetaGraphsNext", "OrdinaryDiffEqBDF", "OrdinaryDiffEqCore", "OrdinaryDiffEqLowOrderRK", "OrdinaryDiffEqNonlinearSolve", "OrdinaryDiffEqRosenbrock", "OrdinaryDiffEqSDIRK", "OrdinaryDiffEqTsit5", "PreallocationTools", "SQLite", "SciMLBase", "SparseArrays", "SparseConnectivityTracer", "StructArrays", "Tables", "TerminalLoggers", "TranscodingStreams"]
2086+
deps = ["ADTypes", "Accessors", "Arrow", "BasicModelInterface", "CodecZstd", "Configurations", "DBInterface", "DataInterpolations", "DataStructures", "Dates", "DiffEqBase", "DiffEqCallbacks", "DifferentiationInterface", "EnumX", "FiniteDiff", "Graphs", "HiGHS", "IterTools", "JuMP", "Legolas", "LineSearches", "LinearAlgebra", "LinearSolve", "Logging", "LoggingExtras", "MetaGraphsNext", "OrdinaryDiffEqBDF", "OrdinaryDiffEqCore", "OrdinaryDiffEqLowOrderRK", "OrdinaryDiffEqNonlinearSolve", "OrdinaryDiffEqRosenbrock", "OrdinaryDiffEqSDIRK", "OrdinaryDiffEqTsit5", "PreallocationTools", "SQLite", "SciMLBase", "SparseArrays", "SparseConnectivityTracer", "SparseMatrixColorings", "StructArrays", "Tables", "TerminalLoggers", "TranscodingStreams"]
20872087
path = "core"
20882088
uuid = "aac5e3d9-0b8f-4d4f-8241-b1a7a9632635"
20892089
version = "2025.1.0"

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
1818
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
1919
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
2020
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
21+
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
2122
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
2223
FindFirstFunctions = "64ca27bc-2ba2-4a57-88aa-44e436879224"
2324
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
@@ -57,6 +58,7 @@ SQLite = "0aa819cd-b072-5ff4-a722-6bc24af294d9"
5758
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
5859
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
5960
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
61+
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
6062
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
6163
TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
6264
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

core/Project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
1717
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
1818
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
1919
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
20+
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
2021
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
2122
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
2223
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
@@ -42,6 +43,7 @@ SQLite = "0aa819cd-b072-5ff4-a722-6bc24af294d9"
4243
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
4344
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
4445
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
46+
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
4547
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
4648
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
4749
TerminalLoggers = "5d786b92-1e48-4d6f-9151-6b4477ca9bed"
@@ -63,6 +65,7 @@ DataStructures = "0.18"
6365
Dates = "1"
6466
DiffEqBase = "6.155"
6567
DiffEqCallbacks = "3.6, 4"
68+
DifferentiationInterface = "0.6.42"
6669
EnumX = "1.0"
6770
FiniteDiff = "2.21"
6871
Graphs = "1.9"
@@ -89,6 +92,7 @@ SQLite = "1.5.1"
8992
SciMLBase = "2.36"
9093
SparseArrays = "1"
9194
SparseConnectivityTracer = "0.6.8"
95+
SparseMatrixColorings = "0.4.13"
9296
Statistics = "1"
9397
StructArrays = "0.6.13, 0.7"
9498
TOML = "1"

core/src/Ribasim.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@ For more granular access, see:
1414
"""
1515
module Ribasim
1616

17+
using ADTypes: AutoForwardDiff, AutoFiniteDiff
18+
using DifferentiationInterface: AutoSparse, prepare_jacobian, jacobian!
19+
using SparseMatrixColorings: sparsity_pattern
20+
1721
# Algorithms for solving ODEs.
1822
using OrdinaryDiffEqCore: OrdinaryDiffEqCore, get_du, AbstractNLSolver
1923
using DiffEqBase: DiffEqBase, calculate_residuals!
@@ -36,7 +40,7 @@ using SciMLBase:
3640

3741
# Automatically detecting the sparsity pattern of the Jacobian of water_balance!
3842
# through operator overloading
39-
using SparseConnectivityTracer: TracerSparsityDetector, jacobian_sparsity, GradientTracer
43+
using SparseConnectivityTracer: GradientTracer
4044

4145
# For efficient sparse computations
4246
using SparseArrays: SparseMatrixCSC, spzeros

core/src/config.jl

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ using DataStructures: DefaultDict
1313
using Dates: DateTime
1414
using Logging: LogLevel, Debug, Info, Warn, Error
1515
using ..Ribasim: Ribasim, isnode, nodetype
16-
using ADTypes: AutoForwardDiff, AutoFiniteDiff
1716
using OrdinaryDiffEqCore: OrdinaryDiffEqAlgorithm, OrdinaryDiffEqNewtonAdaptiveAlgorithm
1817
using OrdinaryDiffEqNonlinearSolve: NLNewton
1918
using OrdinaryDiffEqLowOrderRK: Euler, RK4
@@ -310,10 +309,6 @@ function algorithm(solver::Solver; u0 = [])::OrdinaryDiffEqAlgorithm
310309
kwargs[:step_limiter!] = Ribasim.limit_flow!
311310
end
312311

313-
if function_accepts_kwarg(algotype, :autodiff)
314-
kwargs[:autodiff] = solver.autodiff ? AutoForwardDiff() : AutoFiniteDiff()
315-
end
316-
317312
algotype(; kwargs...)
318313
end
319314

core/src/model.jl

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,33 @@ struct Model{T}
2828
end
2929
end
3030

31+
function get_jac_eval(du::Vector, u::Vector, p::Parameters, solver::Solver)
32+
backend = if solver.autodiff
33+
AutoForwardDiff()
34+
else
35+
AutoFiniteDiff()
36+
end
37+
38+
if solver.sparse
39+
backend = AutoSparse(backend)
40+
end
41+
42+
t = 0.0
43+
44+
# Activate all nodes to catch all possible state dependencies
45+
p.all_nodes_active = true
46+
prep = prepare_jacobian((du, u) -> water_balance(du, u, p, t), du, backend, u)
47+
p.all_nodes_active = false
48+
49+
jac_prototype = sparsity_pattern(prep)
50+
51+
jac =
52+
(J, u, p, t) ->
53+
jacobian!((du, u) -> water_balance!(du, u, p, t), du, J, prep, backend, u)
54+
55+
return jac_prototype, jac
56+
end
57+
3158
function Model(config_path::AbstractString)::Model
3259
config = Config(config_path)
3360
if !valid_config(config)
@@ -159,17 +186,13 @@ function Model(config::Config)::Model
159186
tstops = sort(unique(vcat(tstops...)))
160187
adaptive, dt = convert_dt(config.solver.dt)
161188

162-
jac_prototype = if config.solver.sparse
163-
get_jac_prototype(du0, u0, parameters, t0)
164-
else
165-
nothing
166-
end
167-
RHS = ODEFunction(water_balance!; jac_prototype)
189+
jac_prototype, jac = get_jac_eval(du0, u0, parameters, config.solver)
190+
RHS = ODEFunction(water_balance!; jac_prototype, jac)
168191

169192
prob = ODEProblem(RHS, u0, timespan, parameters)
170193
@debug "Setup ODEProblem."
171194

172-
callback, saved = create_callbacks(parameters, config, u0, saveat)
195+
callback, saved = create_callbacks(parameters, config, saveat)
173196
@debug "Created callbacks."
174197

175198
# Run water_balance! before initializing the integrator. This is because

core/src/util.jl

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -764,18 +764,6 @@ function relaxed_root(x, threshold)
764764
end
765765
end
766766

767-
function get_jac_prototype(du0, u0, p, t0)
768-
p.all_nodes_active = true
769-
jac_prototype = jacobian_sparsity(
770-
(du, u) -> water_balance!(du, u, p, t0),
771-
du0,
772-
u0,
773-
TracerSparsityDetector(),
774-
)
775-
p.all_nodes_active = false
776-
jac_prototype
777-
end
778-
779767
# Custom overloads
780768
reduction_factor(x::GradientTracer, ::Real) = x
781769
low_storage_factor_resistance_node(storage, q::GradientTracer, inflow_id, outflow_id) = q

0 commit comments

Comments
 (0)