Commit 2aa4dd2

Some progress on FFJORD

1 parent: 323aee1

21 files changed: 488 additions & 535 deletions
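Most of the diff below switches keyword arguments from the comma form to Julia's explicit semicolon form. The two call forms are equivalent; a minimal sketch of the convention (using `range` from Base, as in several of the changed files):

```julia
# In Julia, keyword arguments may follow the positional arguments after
# either a comma or a semicolon; the semicolon form makes the boundary
# between positional and keyword arguments explicit.
xs = range(0.0, 1.0, length = 5)   # comma form (old style in this diff)
ys = range(0.0, 1.0; length = 5)   # semicolon form (style adopted here)
@assert collect(xs) == collect(ys) == [0.0, 0.25, 0.5, 0.75, 1.0]
```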

README.md

Lines changed: 1 addition & 0 deletions

````diff
@@ -72,3 +72,4 @@ explore various ways to integrate the two methodologies:
 - `Flux` is no longer re-exported from `DiffEqFlux`. Instead we reexport `Lux`.
 - `NeuralDAE` now allows an optional `du0` as input.
 - `TensorLayer` is now a Lux Neural Network.
+- APIs for quite a few layer constructions have changed. Please refer to the updated documentation for more details.
````

docs/make.jl

Lines changed: 5 additions & 5 deletions

````diff
@@ -1,22 +1,22 @@
 using Documenter, DiffEqFlux

-cp("./docs/Manifest.toml", "./docs/src/assets/Manifest.toml", force = true)
-cp("./docs/Project.toml", "./docs/src/assets/Project.toml", force = true)
+cp("./docs/Manifest.toml", "./docs/src/assets/Manifest.toml"; force = true)
+cp("./docs/Project.toml", "./docs/src/assets/Project.toml"; force = true)

 ENV["GKSwstype"] = "100"
 using Plots
 ENV["DATADEPS_ALWAYS_ACCEPT"] = true

 include("pages.jl")

-makedocs(sitename = "DiffEqFlux.jl",
+makedocs(; sitename = "DiffEqFlux.jl",
     authors = "Chris Rackauckas et al.",
     clean = true, doctest = false, linkcheck = true,
     warnonly = [:docs_block, :missing_docs],
     modules = [DiffEqFlux],
-    format = Documenter.HTML(assets = ["assets/favicon.ico"],
+    format = Documenter.HTML(; assets = ["assets/favicon.ico"],
         canonical = "https://docs.sciml.ai/DiffEqFlux/stable/"),
     pages = pages)

-deploydocs(repo = "github.com/SciML/DiffEqFlux.jl.git";
+deploydocs(; repo = "github.com/SciML/DiffEqFlux.jl.git",
     push_preview = true)
````
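The `cp` calls above use Base Julia's `cp(src, dst; force = true)`, where `force` overwrites an existing destination. A small self-contained sketch of that call in a temporary directory:

```julia
mktempdir() do dir
    src = joinpath(dir, "a.txt")
    dst = joinpath(dir, "b.txt")
    write(src, "hello")
    write(dst, "stale")            # dst already exists
    cp(src, dst; force = true)     # force = true removes dst first
    @assert read(dst, String) == "hello"
end
```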

docs/src/examples/GPUs.md

Lines changed: 10 additions & 10 deletions

````diff
@@ -23,14 +23,14 @@ u0 = Float32[2.0; 0.0] |> gpu
 prob_gpu = ODEProblem(dudt, u0, tspan, ps)

 # Runs on a GPU
-sol_gpu = solve(prob_gpu, Tsit5(), saveat = tsteps)
+sol_gpu = solve(prob_gpu, Tsit5(); saveat = tsteps)
 ```

 Or we could directly use the neural ODE layer function, like:

 ```julia
 using DiffEqFlux: NeuralODE
-prob_neuralode_gpu = NeuralODE(model, tspan, Tsit5(), saveat = tsteps)
+prob_neuralode_gpu = NeuralODE(model, tspan, Tsit5(); saveat = tsteps)
 ```

 If one is using `Lux.Chain`, then the computation takes place on the GPU with
@@ -55,13 +55,13 @@ tsteps = 0.0f0:1.0f-1:10.0f0
 prob_gpu = ODEProblem(dudt2_, u0, tspan, p)

 # Runs on a GPU
-sol_gpu = solve(prob_gpu, Tsit5(), saveat = tsteps)
+sol_gpu = solve(prob_gpu, Tsit5(); saveat = tsteps)
 ```

 or via the NeuralODE struct:

 ```julia
-prob_neuralode_gpu = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps)
+prob_neuralode_gpu = NeuralODE(dudt2, tspan, Tsit5(); saveat = tsteps)
 prob_neuralode_gpu(u0, p, st)
 ```

@@ -83,22 +83,22 @@ rng = Random.default_rng()
 u0 = Float32[2.0; 0.0]
 datasize = 30
 tspan = (0.0f0, 1.5f0)
-tsteps = range(tspan[1], tspan[2], length = datasize)
+tsteps = range(tspan[1], tspan[2]; length = datasize)
 function trueODEfunc(du, u, p, t)
     true_A = [-0.1 2.0; -2.0 -0.1]
     du .= ((u .^ 3)'true_A)'
 end
 prob_trueode = ODEProblem(trueODEfunc, u0, tspan)
 # Make the data into a GPU-based array if the user has a GPU
-ode_data = gpu(solve(prob_trueode, Tsit5(), saveat = tsteps))
+ode_data = gpu(solve(prob_trueode, Tsit5(); saveat = tsteps))

 dudt2 = Chain(x -> x .^ 3, Dense(2, 50, tanh), Dense(50, 2))
 u0 = Float32[2.0; 0.0] |> gpu
 p, st = Lux.setup(rng, dudt2)
 p = p |> ComponentArray |> gpu
 st = st |> gpu

-prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps)
+prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(); saveat = tsteps)

 function predict_neuralode(p)
     gpu(first(prob_neuralode(u0, p, st)))
@@ -119,8 +119,8 @@ callback = function (p, l, pred; doplot = false)
     iter += 1
     display(l)
     # plot current prediction against data
-    plt = scatter(tsteps, Array(ode_data[1, :]), label = "data")
-    scatter!(plt, tsteps, Array(pred[1, :]), label = "prediction")
+    plt = scatter(tsteps, Array(ode_data[1, :]); label = "data")
+    scatter!(plt, tsteps, Array(pred[1, :]); label = "prediction")
     push!(list_plots, plt)
     if doplot
         display(plot(plt))
@@ -132,7 +132,7 @@ adtype = Optimization.AutoZygote()
 optf = Optimization.OptimizationFunction((x, p) -> loss_neuralode(x), adtype)
 optprob = Optimization.OptimizationProblem(optf, p)
 result_neuralode = Optimization.solve(optprob,
-    Adam(0.05),
+    Adam(0.05);
     callback = callback,
     maxiters = 300)
 ```
````

docs/src/examples/augmented_neural_ode.md

Lines changed: 16 additions & 16 deletions

````diff
@@ -27,8 +27,8 @@ function concentric_sphere(dim, inner_radius_range, outer_radius_range,
         push!(data, reshape(random_point_in_sphere(dim, outer_radius_range...), :, 1))
         push!(labels, -ones(1, 1))
     end
-    data = cat(data..., dims = 2)
-    labels = cat(labels..., dims = 2)
+    data = cat(data...; dims = 2)
+    labels = cat(labels...; dims = 2)
     DataLoader((data |> Flux.gpu, labels |> Flux.gpu); batchsize = batch_size,
         shuffle = true,
         partial = false)
@@ -41,7 +41,7 @@ function construct_model(out_dim, input_dim, hidden_dim, augment_dim)
     node = NeuralODE(Flux.Chain(Flux.Dense(input_dim, hidden_dim, relu),
             Flux.Dense(hidden_dim, hidden_dim, relu),
             Flux.Dense(hidden_dim, input_dim)) |> Flux.gpu,
-        (0.0f0, 1.0f0), Tsit5(), save_everystep = false,
+        (0.0f0, 1.0f0), Tsit5(); save_everystep = false,
         reltol = 1.0f-3, abstol = 1.0f-3, save_start = false) |> Flux.gpu
     node = augment_dim == 0 ? node : AugmentedNDELayer(node, augment_dim)
     return Flux.Chain((x, p = node.p) -> node(x, p),
@@ -54,15 +54,15 @@ end
 function plot_contour(model, npoints = 300)
     grid_points = zeros(Float32, 2, npoints^2)
     idx = 1
-    x = range(-4.0f0, 4.0f0, length = npoints)
-    y = range(-4.0f0, 4.0f0, length = npoints)
+    x = range(-4.0f0, 4.0f0; length = npoints)
+    y = range(-4.0f0, 4.0f0; length = npoints)
     for x1 in x, x2 in y
         grid_points[:, idx] .= [x1, x2]
         idx += 1
     end
     sol = reshape(model(grid_points |> Flux.gpu), npoints, npoints) |> Flux.cpu

-    return contour(x, y, sol, fill = true, linewidth = 0.0)
+    return contour(x, y, sol; fill = true, linewidth = 0.0)
 end

 loss_node(x, y) = mean((model(x) .- y) .^ 2)
@@ -91,7 +91,7 @@ opt = Adam(0.005)
 println("Training Neural ODE")

 for _ in 1:10
-    Flux.train!(loss_node, Flux.params(parameters, model), dataloader, opt, cb = cb)
+    Flux.train!(loss_node, Flux.params(parameters, model), dataloader, opt; cb = cb)
 end

 plt_node = plot_contour(model)
@@ -103,7 +103,7 @@ println()
 println("Training Augmented Neural ODE")

 for _ in 1:10
-    Flux.train!(loss_node, Flux.params(parameters, model), dataloader, opt, cb = cb)
+    Flux.train!(loss_node, Flux.params(parameters, model), dataloader, opt; cb = cb)
 end

 plt_anode = plot_contour(model)
@@ -153,8 +153,8 @@ function concentric_sphere(dim, inner_radius_range, outer_radius_range,
         push!(data, reshape(random_point_in_sphere(dim, outer_radius_range...), :, 1))
         push!(labels, -ones(1, 1))
     end
-    data = cat(data..., dims = 2)
-    labels = cat(labels..., dims = 2)
+    data = cat(data...; dims = 2)
+    labels = cat(labels...; dims = 2)
     return DataLoader((data |> Flux.gpu, labels |> Flux.gpu); batchsize = batch_size,
         shuffle = true,
         partial = false)
@@ -181,7 +181,7 @@ function construct_model(out_dim, input_dim, hidden_dim, augment_dim)
     node = NeuralODE(Flux.Chain(Flux.Dense(input_dim, hidden_dim, relu),
             Flux.Dense(hidden_dim, hidden_dim, relu),
             Flux.Dense(hidden_dim, input_dim)) |> Flux.gpu,
-        (0.0f0, 1.0f0), Tsit5(), save_everystep = false,
+        (0.0f0, 1.0f0), Tsit5(); save_everystep = false,
         reltol = 1.0f-3, abstol = 1.0f-3, save_start = false) |> Flux.gpu
     node = augment_dim == 0 ? node : (AugmentedNDELayer(node, augment_dim) |> Flux.gpu)
     return Flux.Chain((x, p = node.p) -> node(x, p),
@@ -200,15 +200,15 @@ Here, we define a utility to plot our model regression results as a heatmap.
 function plot_contour(model, npoints = 300)
     grid_points = zeros(2, npoints^2)
     idx = 1
-    x = range(-4.0f0, 4.0f0, length = npoints)
-    y = range(-4.0f0, 4.0f0, length = npoints)
+    x = range(-4.0f0, 4.0f0; length = npoints)
+    y = range(-4.0f0, 4.0f0; length = npoints)
     for x1 in x, x2 in y
         grid_points[:, idx] .= [x1, x2]
         idx += 1
     end
     sol = reshape(model(grid_points |> Flux.gpu), npoints, npoints) |> Flux.cpu

-    return contour(x, y, sol, fill = true, linewidth = 0.0)
+    return contour(x, y, sol; fill = true, linewidth = 0.0)
 end
 ```

@@ -269,7 +269,7 @@ for `20` epochs.
 model, parameters = construct_model(1, 2, 64, 0)

 for _ in 1:10
-    Flux.train!(loss_node, Flux.params(model, parameters), dataloader, opt, cb = cb)
+    Flux.train!(loss_node, Flux.params(model, parameters), dataloader, opt; cb = cb)
 end
 ```

@@ -288,7 +288,7 @@ a function which can be expressed by the neural ode. For more details and proofs
 model, parameters = construct_model(1, 2, 64, 1)

 for _ in 1:10
-    Flux.train!(loss_node, Flux.params(model, parameters), dataloader, opt, cb = cb)
+    Flux.train!(loss_node, Flux.params(model, parameters), dataloader, opt; cb = cb)
 end
 ```
````

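The `cat(data...; dims = 2)` calls above splat a vector of single-column matrices into one call and concatenate them along the second (column) dimension. A small illustration of the shape this produces:

```julia
# Two 2×1 column matrices, as pushed into `data` in the example above.
cols = [reshape([1.0, 2.0], :, 1), reshape([3.0, 4.0], :, 1)]
# Splatting plus dims = 2 horizontally concatenates them into a 2×2 matrix.
data = cat(cols...; dims = 2)
@assert size(data) == (2, 2)
@assert data == [1.0 3.0; 2.0 4.0]
```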
docs/src/examples/collocation.md

Lines changed: 14 additions & 14 deletions

````diff
@@ -20,20 +20,20 @@ rng = Random.default_rng()
 u0 = Float32[2.0; 0.0]
 datasize = 300
 tspan = (0.0f0, 1.5f0)
-tsteps = range(tspan[1], tspan[2], length = datasize)
+tsteps = range(tspan[1], tspan[2]; length = datasize)

 function trueODEfunc(du, u, p, t)
     true_A = [-0.1 2.0; -2.0 -0.1]
     du .= ((u .^ 3)'true_A)'
 end

 prob_trueode = ODEProblem(trueODEfunc, u0, tspan)
-data = Array(solve(prob_trueode, Tsit5(), saveat = tsteps)) .+ 0.1randn(2, 300)
+data = Array(solve(prob_trueode, Tsit5(); saveat = tsteps)) .+ 0.1randn(2, 300)

 du, u = collocate_data(data, tsteps, EpanechnikovKernel())

 scatter(tsteps, data')
-plot!(tsteps, u', lw = 5)
+plot!(tsteps, u'; lw = 5)
 savefig("colloc.png")
 plot(tsteps, du')
 savefig("colloc_du.png")
@@ -63,11 +63,11 @@ optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
 optprob = Optimization.OptimizationProblem(optf, ComponentArray(pinit))

 result_neuralode = Optimization.solve(optprob,
-    Adam(0.05),
+    Adam(0.05);
     callback = callback,
     maxiters = 10000)

-prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps)
+prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(); saveat = tsteps)
 nn_sol, st = prob_neuralode(u0, result_neuralode.u, st)
 scatter(tsteps, data')
 plot!(nn_sol)
@@ -88,13 +88,13 @@ optf = Optimization.OptimizationFunction((x, p) -> loss_neuralode(x), adtype)
 optprob = Optimization.OptimizationProblem(optf, ComponentArray(pinit))

 numerical_neuralode = Optimization.solve(optprob,
-    Adam(0.05),
+    Adam(0.05);
     callback = callback,
     maxiters = 300)

 nn_sol, st = prob_neuralode(u0, numerical_neuralode.u, st)
 scatter(tsteps, data')
-plot!(nn_sol, lw = 5)
+plot!(nn_sol; lw = 5)
 ```

 ## Generating the Collocation
@@ -112,20 +112,20 @@ rng = Random.default_rng()
 u0 = Float32[2.0; 0.0]
 datasize = 300
 tspan = (0.0f0, 1.5f0)
-tsteps = range(tspan[1], tspan[2], length = datasize)
+tsteps = range(tspan[1], tspan[2]; length = datasize)

 function trueODEfunc(du, u, p, t)
     true_A = [-0.1 2.0; -2.0 -0.1]
     du .= ((u .^ 3)'true_A)'
 end

 prob_trueode = ODEProblem(trueODEfunc, u0, tspan)
-data = Array(solve(prob_trueode, Tsit5(), saveat = tsteps)) .+ 0.1randn(2, 300)
+data = Array(solve(prob_trueode, Tsit5(); saveat = tsteps)) .+ 0.1randn(2, 300)

 du, u = collocate_data(data, tsteps, EpanechnikovKernel())

 scatter(tsteps, data')
-plot!(tsteps, u', lw = 5)
+plot!(tsteps, u'; lw = 5)
 ```

 We can then differentiate the smoothed function to get estimates of the
@@ -165,11 +165,11 @@ optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
 optprob = Optimization.OptimizationProblem(optf, ComponentArray(pinit))

 result_neuralode = Optimization.solve(optprob,
-    Adam(0.05),
+    Adam(0.05);
     callback = callback,
     maxiters = 10000)

-prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps)
+prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(); saveat = tsteps)
 nn_sol, st = prob_neuralode(u0, result_neuralode.u, st)
 scatter(tsteps, data')
 plot!(nn_sol)
@@ -196,13 +196,13 @@ optf = Optimization.OptimizationFunction((x, p) -> loss_neuralode(x), adtype)
 optprob = Optimization.OptimizationProblem(optf, ComponentArray(pinit))

 numerical_neuralode = Optimization.solve(optprob,
-    Adam(0.05),
+    Adam(0.05);
     callback = callback,
     maxiters = 300)

 nn_sol, st = prob_neuralode(u0, numerical_neuralode.u, st)
 scatter(tsteps, data')
-plot!(nn_sol, lw = 5)
+plot!(nn_sol; lw = 5)
 ```

 This method then has a good global starting position, making it less
````

docs/src/examples/hamiltonian_nn.md

Lines changed: 10 additions & 10 deletions

````diff
@@ -12,7 +12,7 @@ Now we make some simplifying assumptions, and assign ``m = 1`` and ``k = 1``. An
 using Lux, DiffEqFlux, OrdinaryDiffEq, Statistics, Plots, Zygote, ForwardDiff, Random,
     ComponentArrays, Optimization, OptimizationOptimisers, IterTools

-t = range(0.0f0, 1.0f0, length = 1024)
+t = range(0.0f0, 1.0f0; length = 1024)
 π_32 = Float32(π)
 q_t = reshape(sin.(2π_32 * t), 1, :)
 p_t = reshape(cos.(2π_32 * t), 1, :)
@@ -46,12 +46,12 @@ res = Optimization.solve(opt_prob, opt, dataloader)
 ps_trained = res.u

-model = NeuralHamiltonianDE(hnn, (0.0f0, 1.0f0), Tsit5(), save_everystep = false,
+model = NeuralHamiltonianDE(hnn, (0.0f0, 1.0f0), Tsit5(); save_everystep = false,
     save_start = true, saveat = t)

 pred = Array(first(model(data[:, 1], ps_trained, st)))
-plot(data[1, :], data[2, :], lw = 4, label = "Original")
-plot!(pred[1, :], pred[2, :], lw = 4, label = "Predicted")
+plot(data[1, :], data[2, :]; lw = 4, label = "Original")
+plot!(pred[1, :], pred[2, :]; lw = 4, label = "Predicted")
 xlabel!("Position (q)")
 ylabel!("Momentum (p)")
 ```
@@ -66,15 +66,15 @@ The HNN predicts the gradients ``(\dot q, \dot p)`` given ``(q, p)``. Hence, we
 using Flux, DiffEqFlux, DifferentialEquations, Statistics, Plots, ReverseDiff, Random,
     IterTools, Lux, ComponentArrays, Optimization, OptimizationOptimisers

-t = range(0.0f0, 1.0f0, length = 1024)
+t = range(0.0f0, 1.0f0; length = 1024)
 π_32 = Float32(π)
 q_t = reshape(sin.(2π_32 * t), 1, :)
 p_t = reshape(cos.(2π_32 * t), 1, :)
 dqdt = 2π_32 .* p_t
 dpdt = -2π_32 .* q_t

-data = cat(q_t, p_t, dims = 1)
-target = cat(dqdt, dpdt, dims = 1)
+data = cat(q_t, p_t; dims = 1)
+target = cat(dqdt, dpdt; dims = 1)
 B = 256
 NEPOCHS = 500
 dataloader = ncycle(((selectdim(data, 2, ((i - 1) * B + 1):(min(i * B, size(data, 2)))),
@@ -117,12 +117,12 @@ ps_trained = res.u
 In order to visualize the learned trajectories, we need to solve the ODE. We will use the `NeuralHamiltonianDE` layer, which is essentially a wrapper over `HamiltonianNN` layer, and solves the ODE.

 ```@example hamiltonian
-model = NeuralHamiltonianDE(hnn, (0.0f0, 1.0f0), Tsit5(), save_everystep = false,
+model = NeuralHamiltonianDE(hnn, (0.0f0, 1.0f0), Tsit5(); save_everystep = false,
     save_start = true, saveat = t)

 pred = Array(first(model(data[:, 1], ps_trained, st)))
-plot(data[1, :], data[2, :], lw = 4, label = "Original")
-plot!(pred[1, :], pred[2, :], lw = 4, label = "Predicted")
+plot(data[1, :], data[2, :]; lw = 4, label = "Original")
+plot!(pred[1, :], pred[2, :]; lw = 4, label = "Predicted")
 xlabel!("Position (q)")
 ylabel!("Momentum (p)")
 ```
````
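The Hamiltonian example above builds its training data as 1×N row matrices via `reshape(…, 1, :)` and stacks them into a 2×N array (row 1 = positions, row 2 = momenta) with `cat(…; dims = 1)`. A small shape check of that construction, with a shorter time grid:

```julia
t = range(0.0f0, 1.0f0; length = 8)
π_32 = Float32(π)
q_t = reshape(sin.(2π_32 * t), 1, :)   # 1×8 row of positions
p_t = reshape(cos.(2π_32 * t), 1, :)   # 1×8 row of momenta
data = cat(q_t, p_t; dims = 1)         # 2×8: vertical concatenation
@assert size(data) == (2, 8)
# At t = 0: q = sin(0) = 0, p = cos(0) = 1.
@assert data[1, 1] == 0.0f0 && data[2, 1] == 1.0f0
```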
