Skip to content

Commit 434e752

Browse files
Merge pull request #477 from avik-pal/ap/bvp
Revisiting Boundary Value Problems
2 parents eb92995 + 593e7f0 commit 434e752

File tree

5 files changed

+140
-125
lines changed

5 files changed

+140
-125
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SciMLBase"
22
uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
33
authors = ["Chris Rackauckas <[email protected]> and contributors"]
4-
version = "1.98.1"
4+
version = "2.0.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/problems/bvp_problems.jl

Lines changed: 64 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,11 @@ $(TYPEDEF)
33
"""
44
struct StandardBVProblem end
55

6+
"""
7+
$(TYPEDEF)
8+
"""
9+
struct TwoPointBVProblem end
10+
611
@doc doc"""
712
813
Defines an BVP problem.
@@ -17,7 +22,7 @@ condition ``u_0`` which define an ODE:
1722
\frac{du}{dt} = f(u,p,t)
1823
```
1924
20-
along with an implicit function `bc!` which defines the residual equation, where
25+
along with an implicit function `bc` which defines the residual equation, where
2126
2227
```math
2328
bc(u,p,t) = 0
@@ -36,22 +41,27 @@ u(t_f) = b
3641
### Constructors
3742
3843
```julia
39-
TwoPointBVProblem{isinplace}(f,bc!,u0,tspan,p=NullParameters();kwargs...)
40-
BVProblem{isinplace}(f,bc!,u0,tspan,p=NullParameters();kwargs...)
44+
TwoPointBVProblem{isinplace}(f,bc,u0,tspan,p=NullParameters();kwargs...)
45+
BVProblem{isinplace}(f,bc,u0,tspan,p=NullParameters();kwargs...)
4146
```
4247
4348
or if we have an initial guess function `initialGuess(t)` for the given BVP,
4449
we can pass the initial guess to the problem constructors:
4550
4651
```julia
47-
TwoPointBVProblem{isinplace}(f,bc!,initialGuess,tspan,p=NullParameters();kwargs...)
48-
BVProblem{isinplace}(f,bc!,initialGuess,tspan,p=NullParameters();kwargs...)
52+
TwoPointBVProblem{isinplace}(f,bc,initialGuess,tspan,p=NullParameters();kwargs...)
53+
BVProblem{isinplace}(f,bc,initialGuess,tspan,p=NullParameters();kwargs...)
4954
```
5055
51-
For any BVP problem type, `bc!` is the inplace function:
56+
For any BVP problem type, `bc` must be inplace if `f` is inplace. Otherwise it must be
57+
out-of-place.
58+
59+
If the bvp is a StandardBVProblem (also known as a Multi-Point BV Problem) it must define
60+
either of the following functions
5261
5362
```julia
5463
bc!(residual, u, p, t)
64+
residual = bc(u, p, t)
5565
```
5666
5767
where `residual` computed from the current `u`. `u` is an array of solution values
@@ -61,6 +71,16 @@ time points, and for shooting type methods `u=sol` the ODE solution.
6171
Note that all features of the `ODESolution` are present in this form.
6272
In both cases, the size of the residual matches the size of the initial condition.
6373
74+
If the bvp is a TwoPointBVProblem it must define either of the following functions
75+
76+
```julia
77+
bc!((resid_a, resid_b), (u_a, u_b), p)
78+
resid_a, resid_b = bc((u_a, u_b), p)
79+
```
80+
81+
where `resid_a` and `resid_b` are the residuals at the two endpoints, `u_a` and `u_b` are
82+
the solution values at the two endpoints, and `p` are the parameters.
83+
6484
Parameters are optional, and if not given, then a `NullParameters()` singleton
6585
will be used which will throw nice errors if you try to index non-existent
6686
parameters. Any extra keyword arguments are passed on to the solvers. For example,
@@ -88,16 +108,20 @@ struct BVProblem{uType, tType, isinplace, P, F, BF, PT, K} <:
88108
problem_type::PT
89109
kwargs::K
90110

91-
@add_kwonly function BVProblem{iip}(f::AbstractBVPFunction{iip}, bc, u0, tspan,
92-
p = NullParameters(),
93-
problem_type = StandardBVProblem();
94-
kwargs...) where {iip}
111+
@add_kwonly function BVProblem{iip}(f::AbstractBVPFunction{iip, TP}, bc, u0, tspan,
112+
p = NullParameters(); problem_type=nothing, kwargs...) where {iip, TP}
95113
_tspan = promote_tspan(tspan)
96114
warn_paramtype(p)
97-
new{typeof(u0), typeof(_tspan), isinplace(f), typeof(p),
98-
typeof(f), typeof(f.bc),
99-
typeof(problem_type), typeof(kwargs)}(f, f.bc, u0, _tspan, p,
100-
problem_type, kwargs)
115+
prob_type = TP ? TwoPointBVProblem() : StandardBVProblem()
116+
# Needed to ensure that `problem_type` doesn't get passed in kwargs
117+
if problem_type === nothing
118+
problem_type = prob_type
119+
else
120+
@assert prob_type === problem_type "This indicates incorrect problem type specification! Users should never pass in `problem_type` kwarg, this exists exclusively for internal use."
121+
end
122+
return new{typeof(u0), typeof(_tspan), iip, typeof(p), typeof(f), typeof(bc),
123+
typeof(problem_type), typeof(kwargs)}(f, bc, u0, _tspan, p, problem_type,
124+
kwargs)
101125
end
102126

103127
function BVProblem{iip}(f, bc, u0, tspan, p = NullParameters(); kwargs...) where {iip}
@@ -107,52 +131,43 @@ end
107131

108132
TruncatedStacktraces.@truncate_stacktrace BVProblem 3 1 2
109133

110-
function BVProblem(f::AbstractBVPFunction, u0, tspan, p = NullParameters(); kwargs...)
111-
BVProblem{isinplace(f)}(f, f.bc, u0, tspan, p; kwargs...)
134+
function BVProblem(f, bc, u0, tspan, p = NullParameters(); kwargs...)
135+
iip = isinplace(f, 4)
136+
return BVProblem{iip}(BVPFunction{iip}(f, bc), bc, u0, tspan, p; kwargs...)
112137
end
113138

114-
function BVProblem(f, bc, u0, tspan, p = NullParameters(); kwargs...)
115-
BVProblem(BVPFunction(f, bc), u0, tspan, p; kwargs...)
139+
function BVProblem(f::AbstractBVPFunction, u0, tspan, p = NullParameters(); kwargs...)
140+
return BVProblem{isinplace(f)}(f, f.bc, u0, tspan, p; kwargs...)
116141
end
117142

118-
"""
119-
$(TYPEDEF)
120-
"""
121-
struct TwoPointBVPFunction{bF}
122-
bc::bF
143+
# This is mostly a fake stuct and isn't used anywhere
144+
# But we need it for function calls like TwoPointBVProblem{iip}(...) = ...
145+
struct TwoPointBVPFunction{iip} end
146+
147+
@inline TwoPointBVPFunction(args...; kwargs...) = BVPFunction(args...; kwargs..., twopoint=true)
148+
@inline function TwoPointBVPFunction{iip}(args...; kwargs...) where {iip}
149+
return BVPFunction{iip}(args...; kwargs..., twopoint=true)
123150
end
124-
TwoPointBVPFunction(; bc = error("No argument bc")) = TwoPointBVPFunction(bc)
125-
(f::TwoPointBVPFunction)(residual, ua, ub, p) = f.bc(residual, ua, ub, p)
126-
(f::TwoPointBVPFunction)(residual, u, p) = f.bc(residual, u[1], u[end], p)
127151

128-
"""
129-
$(TYPEDEF)
130-
"""
131-
struct TwoPointBVProblem{iip} end
132-
function TwoPointBVProblem(f, bc, u0, tspan, p = NullParameters(); kwargs...)
133-
iip = isinplace(f, 4)
134-
TwoPointBVProblem{iip}(f, bc, u0, tspan, p; kwargs...)
152+
function TwoPointBVProblem(f, bc, u0, tspan, p = NullParameters();
153+
bcresid_prototype=nothing, kwargs...)
154+
return TwoPointBVProblem(TwoPointBVPFunction(f, bc; bcresid_prototype), u0, tspan, p;
155+
kwargs...)
135156
end
136-
function TwoPointBVProblem{iip}(f, bc, u0, tspan, p = NullParameters();
137-
kwargs...) where {iip}
138-
BVProblem{iip}(f, TwoPointBVPFunction(bc), u0, tspan, p; kwargs...)
157+
function TwoPointBVProblem(f::AbstractBVPFunction{iip, twopoint}, u0, tspan,
158+
p = NullParameters(); kwargs...) where {iip, twopoint}
159+
@assert twopoint "`TwoPointBVProblem` can only be used with a `TwoPointBVPFunction`. Instead of using `BVPFunction`, use `TwoPointBVPFunction` or pass a kwarg `twopoint=true` during the construction of the `BVPFunction`."
160+
return BVProblem{iip}(f, f.bc, u0, tspan, p; kwargs...)
139161
end
140162

141163
# Allow previous timeseries solution
142-
function TwoPointBVProblem(f::AbstractODEFunction,
143-
bc,
144-
sol::T,
145-
tspan::Tuple,
146-
p = NullParameters()) where {T <: AbstractTimeseriesSolution}
147-
TwoPointBVProblem(f, bc, sol.u, tspan, p)
164+
function TwoPointBVProblem(f::AbstractODEFunction, bc, sol::T, tspan::Tuple,
165+
p = NullParameters(); kwargs...) where {T <: AbstractTimeseriesSolution}
166+
return TwoPointBVProblem(f, bc, sol.u, tspan, p; kwargs...)
148167
end
149168
# Allow initial guess function for the initial guess
150-
function TwoPointBVProblem(f::AbstractODEFunction,
151-
bc,
152-
initialGuess,
153-
tspan::AbstractVector,
154-
p = NullParameters();
155-
kwargs...)
169+
function TwoPointBVProblem(f::AbstractODEFunction, bc, initialGuess, tspan::AbstractVector,
170+
p = NullParameters(); kwargs...)
156171
u0 = [initialGuess(i) for i in tspan]
157-
TwoPointBVProblem(f, bc, u0, (tspan[1], tspan[end]), p)
172+
return TwoPointBVProblem(f, bc, u0, (tspan[1], tspan[end]), p; kwargs...)
158173
end

src/scimlfunctions.jl

Lines changed: 45 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -2124,8 +2124,7 @@ TruncatedStacktraces.@truncate_stacktrace OptimizationFunction 1 2
21242124
"""
21252125
$(TYPEDEF)
21262126
"""
2127-
abstract type AbstractBVPFunction{iip} <:
2128-
AbstractDiffEqFunction{iip} end
2127+
abstract type AbstractBVPFunction{iip, twopoint} <: AbstractDiffEqFunction{iip} end
21292128

21302129
@doc doc"""
21312130
BVPFunction{iip,F,BF,TMM,Ta,Tt,TJ,BCTJ,JVP,VJP,JP,BCJP,SP,TW,TWt,TPJ,S,S2,S3,O,TCV,BCTCV} <: AbstractBVPFunction{iip,specialize}
@@ -2230,11 +2229,9 @@ For more details on this argument, see the ODEFunction documentation.
22302229
22312230
The fields of the BVPFunction type directly match the names of the inputs.
22322231
"""
2233-
struct BVPFunction{iip, specialize, F, BF, TMM, Ta, Tt, TJ, BCTJ, JVP, VJP, JP,
2234-
BCJP, SP, TW, TWt,
2235-
TPJ,
2236-
S, S2, S3, O, TCV, BCTCV,
2237-
SYS} <: AbstractBVPFunction{iip}
2232+
struct BVPFunction{iip, specialize, twopoint, F, BF, TMM, Ta, Tt, TJ, BCTJ, JVP, VJP,
2233+
JP, BCJP, BCRP, SP, TW, TWt, TPJ, S, S2, S3, O, TCV, BCTCV,
2234+
SYS} <: AbstractBVPFunction{iip, twopoint}
22382235
f::F
22392236
bc::BF
22402237
mass_matrix::TMM
@@ -2246,6 +2243,7 @@ struct BVPFunction{iip, specialize, F, BF, TMM, Ta, Tt, TJ, BCTJ, JVP, VJP, JP,
22462243
vjp::VJP
22472244
jac_prototype::JP
22482245
bcjac_prototype::BCJP
2246+
bcresid_prototype::BCRP
22492247
sparsity::SP
22502248
Wfact::TW
22512249
Wfact_t::TWt
@@ -3648,9 +3646,8 @@ function NonlinearFunction{iip, specialize}(f;
36483646
nothing,
36493647
sys = __has_sys(f) ? f.sys : nothing,
36503648
resid_prototype = __has_resid_prototype(f) ? f.resid_prototype : nothing) where {
3651-
iip,
3652-
specialize,
3653-
}
3649+
iip, specialize}
3650+
36543651
if mass_matrix === I && typeof(f) <: Tuple
36553652
mass_matrix = ((I for i in 1:length(f))...,)
36563653
end
@@ -3814,35 +3811,28 @@ function OptimizationFunction{iip}(f, adtype::AbstractADType = NoAD();
38143811
cons_expr, sys)
38153812
end
38163813

3817-
function BVPFunction{iip, specialize}(f, bc;
3818-
mass_matrix = __has_mass_matrix(f) ? f.mass_matrix :
3819-
I,
3814+
function BVPFunction{iip, specialize, twopoint}(f, bc;
3815+
mass_matrix = __has_mass_matrix(f) ? f.mass_matrix : I,
38203816
analytic = __has_analytic(f) ? f.analytic : nothing,
38213817
tgrad = __has_tgrad(f) ? f.tgrad : nothing,
38223818
jac = __has_jac(f) ? f.jac : nothing,
38233819
bcjac = __has_jac(bc) ? bc.jac : nothing,
38243820
jvp = __has_jvp(f) ? f.jvp : nothing,
38253821
vjp = __has_vjp(f) ? f.vjp : nothing,
3826-
jac_prototype = __has_jac_prototype(f) ?
3827-
f.jac_prototype :
3828-
nothing,
3829-
bcjac_prototype = __has_jac_prototype(bc) ?
3830-
bc.jac_prototype :
3831-
nothing,
3832-
sparsity = __has_sparsity(f) ? f.sparsity :
3833-
jac_prototype,
3822+
jac_prototype = __has_jac_prototype(f) ? f.jac_prototype : nothing,
3823+
bcjac_prototype = __has_jac_prototype(bc) ? bc.jac_prototype : nothing,
3824+
bcresid_prototype = nothing,
3825+
sparsity = __has_sparsity(f) ? f.sparsity : jac_prototype,
38343826
Wfact = __has_Wfact(f) ? f.Wfact : nothing,
38353827
Wfact_t = __has_Wfact_t(f) ? f.Wfact_t : nothing,
38363828
paramjac = __has_paramjac(f) ? f.paramjac : nothing,
38373829
syms = __has_syms(f) ? f.syms : nothing,
38383830
indepsym = __has_indepsym(f) ? f.indepsym : nothing,
3839-
paramsyms = __has_paramsyms(f) ? f.paramsyms :
3840-
nothing,
3841-
observed = __has_observed(f) ? f.observed :
3842-
DEFAULT_OBSERVED,
3831+
paramsyms = __has_paramsyms(f) ? f.paramsyms : nothing,
3832+
observed = __has_observed(f) ? f.observed : DEFAULT_OBSERVED,
38433833
colorvec = __has_colorvec(f) ? f.colorvec : nothing,
38443834
bccolorvec = __has_colorvec(bc) ? bc.colorvec : nothing,
3845-
sys = __has_sys(f) ? f.sys : nothing) where {iip, specialize}
3835+
sys = __has_sys(f) ? f.sys : nothing) where {iip, specialize, twopoint}
38463836
if mass_matrix === I && typeof(f) <: Tuple
38473837
mass_matrix = ((I for i in 1:length(f))...,)
38483838
end
@@ -3882,7 +3872,7 @@ function BVPFunction{iip, specialize}(f, bc;
38823872
_bccolorvec = bccolorvec
38833873
end
38843874

3885-
bciip = isinplace(bc, 4, "bc", iip)
3875+
bciip = !twopoint ? isinplace(bc, 4, "bc", iip) : isinplace(bc, 3, "bc", iip)
38863876
jaciip = jac !== nothing ? isinplace(jac, 4, "jac", iip) : iip
38873877
bcjaciip = bcjac !== nothing ? isinplace(bcjac, 4, "bcjac", bciip) : bciip
38883878
tgradiip = tgrad !== nothing ? isinplace(tgrad, 4, "tgrad", iip) : iip
@@ -3892,66 +3882,62 @@ function BVPFunction{iip, specialize}(f, bc;
38923882
Wfact_tiip = Wfact_t !== nothing ? isinplace(Wfact_t, 5, "Wfact_t", iip) : iip
38933883
paramjaciip = paramjac !== nothing ? isinplace(paramjac, 4, "paramjac", iip) : iip
38943884

3895-
nonconforming = (jaciip,
3896-
tgradiip,
3897-
jvpiip,
3898-
vjpiip,
3899-
Wfactiip,
3900-
Wfact_tiip,
3885+
nonconforming = (bciip, jaciip, tgradiip, jvpiip, vjpiip, Wfactiip, Wfact_tiip,
39013886
paramjaciip) .!= iip
39023887
bc_nonconforming = bcjaciip .!= bciip
39033888
if any(nonconforming)
39043889
nonconforming = findall(nonconforming)
3905-
functions = ["jac", "bcjac", "tgrad", "jvp", "vjp", "Wfact", "Wfact_t", "paramjac"][nonconforming]
3890+
functions = ["bc", "jac", "bcjac", "tgrad", "jvp", "vjp", "Wfact", "Wfact_t",
3891+
"paramjac"][nonconforming]
39063892
throw(NonconformingFunctionsError(functions))
39073893
end
39083894

3895+
if twopoint
3896+
if iip && (bcresid_prototype === nothing || length(bcresid_prototype) != 2)
3897+
error("bcresid_prototype must be a tuple / indexable collection of length 2 for a inplace TwoPointBVPFunction")
3898+
end
3899+
if bcresid_prototype !== nothing && length(bcresid_prototype) == 2
3900+
bcresid_prototype = ArrayPartition(bcresid_prototype[1], bcresid_prototype[2])
3901+
end
3902+
end
3903+
39093904
if any(bc_nonconforming)
39103905
bc_nonconforming = findall(bc_nonconforming)
39113906
functions = ["bcjac"][bc_nonconforming]
39123907
throw(NonconformingFunctionsError(functions))
39133908
end
39143909

39153910
if specialize === NoSpecialize
3916-
BVPFunction{iip, specialize, Any, Any, Any, Any, Any,
3917-
Any, Any, Any, Any, Any, Any, Any, Any, Any,
3911+
BVPFunction{iip, specialize, twopoint, Any, Any, Any, Any, Any,
3912+
Any, Any, Any, Any, Any, Any, Any, Any, Any, Any,
39183913
Any, typeof(syms), typeof(indepsym), typeof(paramsyms),
39193914
Any, typeof(_colorvec), typeof(_bccolorvec), Any}(f, bc, mass_matrix,
3920-
analytic,
3921-
tgrad,
3922-
jac, bcjac, jvp, vjp,
3923-
jac_prototype,
3924-
bcjac_prototype,
3925-
sparsity, Wfact,
3926-
Wfact_t,
3927-
paramjac, syms,
3928-
indepsym, paramsyms,
3929-
observed,
3915+
analytic, tgrad, jac, bcjac, jvp, vjp, jac_prototype,
3916+
bcjac_prototype, bcresid_prototype,
3917+
sparsity, Wfact, Wfact_t, paramjac, syms, indepsym, paramsyms, observed,
39303918
_colorvec, _bccolorvec, sys)
39313919
else
3932-
BVPFunction{iip, specialize, typeof(f), typeof(bc), typeof(mass_matrix),
3933-
typeof(analytic),
3934-
typeof(tgrad),
3935-
typeof(jac), typeof(bcjac), typeof(jvp), typeof(vjp), typeof(jac_prototype),
3936-
typeof(bcjac_prototype),
3937-
typeof(sparsity), typeof(Wfact), typeof(Wfact_t),
3938-
typeof(paramjac), typeof(syms), typeof(indepsym), typeof(paramsyms),
3939-
typeof(observed),
3920+
BVPFunction{iip, specialize, twopoint, typeof(f), typeof(bc), typeof(mass_matrix),
3921+
typeof(analytic), typeof(tgrad), typeof(jac), typeof(bcjac), typeof(jvp),
3922+
typeof(vjp), typeof(jac_prototype),
3923+
typeof(bcjac_prototype), typeof(bcresid_prototype), typeof(sparsity),
3924+
typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(syms),
3925+
typeof(indepsym), typeof(paramsyms), typeof(observed),
39403926
typeof(_colorvec), typeof(_bccolorvec), typeof(sys)}(f, bc, mass_matrix, analytic,
39413927
tgrad, jac, bcjac, jvp, vjp,
3942-
jac_prototype, bcjac_prototype, sparsity,
3928+
jac_prototype, bcjac_prototype, bcresid_prototype, sparsity,
39433929
Wfact, Wfact_t, paramjac,
39443930
syms, indepsym, paramsyms, observed,
39453931
_colorvec, _bccolorvec, sys)
39463932
end
39473933
end
39483934

3949-
function BVPFunction{iip}(f, bc; kwargs...) where {iip}
3950-
BVPFunction{iip, FullSpecialize}(f, bc; kwargs...)
3935+
function BVPFunction{iip}(f, bc; twopoint::Bool=false, kwargs...) where {iip}
3936+
BVPFunction{iip, FullSpecialize, twopoint}(f, bc; kwargs...)
39513937
end
39523938
BVPFunction{iip}(f::BVPFunction, bc; kwargs...) where {iip} = f
3953-
function BVPFunction(f, bc; kwargs...)
3954-
BVPFunction{isinplace(f, 4), FullSpecialize}(f, bc; kwargs...)
3939+
function BVPFunction(f, bc; twopoint::Bool=false, kwargs...)
3940+
BVPFunction{isinplace(f, 4), FullSpecialize, twopoint}(f, bc; kwargs...)
39553941
end
39563942
BVPFunction(f::BVPFunction; kwargs...) = f
39573943

test/downstream/ensemble_bvp.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,4 @@ tspan = (0.0, pi / 2)
1919
p = [rand()]
2020
bvp = BVProblem(ode!, bc!, initial_guess, tspan, p)
2121
ensemble_prob = EnsembleProblem(bvp, prob_func = prob_func)
22-
sim = solve(ensemble_prob, GeneralMIRK4(), trajectories = 10, dt = 0.1)
22+
sim = solve(ensemble_prob, MIRK4(), trajectories = 10, dt = 0.1)

0 commit comments

Comments
 (0)