Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
57ba954
add integrand interface
lxvm Sep 16, 2023
21b7895
add InplaceBatchIntegrand
lxvm Sep 16, 2023
6a2038a
format and include
lxvm Sep 17, 2023
3f77759
make the IntegralFunctions
lxvm Sep 19, 2023
f56d654
canonicalize
ChrisRackauckas Sep 19, 2023
e3ee453
Remove error checking on function definition of batch integral
ChrisRackauckas Sep 19, 2023
318e79f
add error test on incorrect integral function dispatches
ChrisRackauckas Sep 19, 2023
783b88e
argument amounts testing
ChrisRackauckas Sep 19, 2023
fcc7edb
some better utils checks
ChrisRackauckas Sep 19, 2023
5a37040
apply format
lxvm Sep 19, 2023
b02e470
fix integralfunction iip
lxvm Sep 19, 2023
95bdb1d
rename integrand_prototype to integral_prototype
lxvm Sep 19, 2023
5675e6f
Update test/function_building_error_messages.jl
ChrisRackauckas Sep 21, 2023
774e4be
fix typos
ChrisRackauckas Sep 21, 2023
bbe691b
revert naming to integrand_prototype
lxvm Sep 21, 2023
c0f4062
wrap integrand with IntegralFunction in IntegralProblem
lxvm Sep 21, 2023
740576a
make integral functions callable
lxvm Sep 21, 2023
3f7d1fb
simplify IntegralProblem definition
lxvm Sep 21, 2023
0deeefb
update docstrings
lxvm Sep 21, 2023
e27965d
apply format
lxvm Sep 21, 2023
5be7d7a
remove output_prototype
lxvm Sep 21, 2023
8ebfe42
add deprecation method
lxvm Sep 21, 2023
e6a0547
Update src/problems/basic_problems.jl
ChrisRackauckas Sep 21, 2023
a3a09d4
Merge branch 'master' into integrands
ChrisRackauckas Sep 21, 2023
619ac07
Update test/function_building_error_messages.jl
ChrisRackauckas Sep 21, 2023
83f933d
Update function_building_error_messages.jl
ChrisRackauckas Sep 21, 2023
7740dd4
fix default batch and dispatch
lxvm Sep 21, 2023
a6fd63a
Change version just to run tests
ChrisRackauckas Sep 21, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions src/SciMLBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -589,6 +589,14 @@ abstract type AbstractDiffEqFunction{iip} <:
"""
$(TYPEDEF)

Base for types defining integrand functions.
"""
abstract type AbstractIntegralFunction{iip} <:
AbstractSciMLFunction{iip} end

"""
$(TYPEDEF)

Base for types defining optimization functions.
"""
abstract type AbstractOptimizationFunction{iip} <: AbstractSciMLFunction{iip} end
Expand Down Expand Up @@ -659,7 +667,9 @@ function specialization(::Union{ODEFunction{iip, specialize},
RODEFunction{iip, specialize},
NonlinearFunction{iip, specialize},
OptimizationFunction{iip, specialize},
BVPFunction{iip, specialize}}) where {iip,
BVPFunction{iip, specialize},
IntegralFunction{iip, specialize},
BatchIntegralFunction{iip, specialize}}) where {iip,
specialize}
specialize
end
Expand Down Expand Up @@ -787,7 +797,8 @@ export remake

export ODEFunction, DiscreteFunction, ImplicitDiscreteFunction, SplitFunction, DAEFunction,
DDEFunction, SDEFunction, SplitSDEFunction, RODEFunction, SDDEFunction,
IncrementingODEFunction, NonlinearFunction, IntervalNonlinearFunction, BVPFunction
IncrementingODEFunction, NonlinearFunction, IntervalNonlinearFunction, BVPFunction,
IntegralFunction, BatchIntegralFunction

export OptimizationFunction

Expand Down
2 changes: 1 addition & 1 deletion src/problems/basic_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ struct SampledIntegralProblem{Y, X, D, K} <: AbstractIntegralProblem{false}
@add_kwonly function SampledIntegralProblem(y::AbstractArray, x::AbstractVector;
dim = ndims(y),
kwargs...)
@assert dim <= ndims(y) "The integration dimension `dim` is larger than the number of dimensions of the integrand `y`"
@assert dim<=ndims(y) "The integration dimension `dim` is larger than the number of dimensions of the integrand `y`"
@assert length(x)==size(y, dim) "The integrand `y` must have the same length as the sampling points `x` along the integrated dimension."
@assert axes(x, 1)==axes(y, dim) "The integrand `y` must obey the same indexing as the sampling points `x` along the integrated dimension."
new{typeof(y), typeof(x), Val{dim}, typeof(kwargs)}(y, x, Val(dim), kwargs)
Expand Down
189 changes: 188 additions & 1 deletion src/scimlfunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,27 @@ function Base.showerror(io::IO, e::NonconformingFunctionsError)
printstyled(io, e.nonconforming; bold = true, color = :red)
end

const INTEGRAND_MISMATCH_FUNCTIONS_ERROR_MESSAGE = """
Nonconforming functions detected. If an integrand function `f` is defined
as out-of-place (`f(u,p)`), then no integral_prototype can be passed into the
function constructor. Likewise if `f` is defined as in-place (`f(out,u,p)`), then
an integral_prototype is required. Either change the use of the function
constructor or define the appropriate dispatch for `f`.
"""

struct IntegrandMismatchFunctionError <: Exception
iip::Bool
integrand_passed::Bool
end

function Base.showerror(io::IO, e::IntegrandMismatchFunctionError)
println(io, INTEGRAND_MISMATCH_FUNCTIONS_ERROR_MESSAGE)
print(io, "Mismatch: IIP=")
printstyled(io, e.iip; bold = true, color = :red)
print(io, ", Integrand passed=")
printstyled(io, e.integrand_passed; bold = true, color = :red)
end

"""
$(TYPEDEF)
"""
Expand Down Expand Up @@ -2261,6 +2282,115 @@ end

TruncatedStacktraces.@truncate_stacktrace BVPFunction 1 2

@doc doc"""
IntegralFunction{iip,specialize,F} <: AbstractIntegralFunction{iip}

A representation of an integrand `f` defined by:

```math
f(u, p)
```

For an in-place form of `f` see the `iip` section below for details on in-place or out-of-place
handling.

```julia
IntegralFunction{iip,specialize}(f, [integral_prototype])
```

Note that only `f` is required, and in the case of inplace integrands a mutable container
`integral_prototype` to store the result of the integral. If `integral_prototype` is present,
`f` is interpreted as in-place, and otherwise `f` is assumed to be out-of-place.

## iip: In-Place vs Out-Of-Place

Out-of-place functions must be of the form ``f(u, p)`` and in-place functions of the form
``f(y, u, p)``. Since `f` is allowed to return any type (e.g. real or complex numbers or
arrays), in-place functions must provide a container `integral_prototype` that is of the
right type for the final result of the integral, and the result is written to this container
in-place. When in-place forms are used, in-place array operations may be used by algorithms
to reduce allocations. If `integral_prototype` is not provided, `f` is assumed to be
out-of-place and quadrature is performed assuming immutable return types.

## specialize

This field is currently unused

## Fields

The fields of the IntegralFunction type directly match the names of the inputs.
"""
struct IntegralFunction{iip, specialize, F, T} <:
AbstractIntegralFunction{iip}
f::F
integral_prototype::T
end

TruncatedStacktraces.@truncate_stacktrace IntegralFunction 1 2

@doc doc"""
BatchIntegralFunction{iip,specialize,F,T,Y,J,TJ,TPJ,Ta,S,JP,SP,TCV,O} <:
AbstractIntegralFunction{iip}

A representation of an integrand `f` that can be evaluated at multiple points simultaneously
using threads, the gpu, or distributed memory defined by:

```math
f(y, u, p)
```

``u`` is a vector whose elements correspond to distinct evaluation points to `f`, whose
output must be returned in the corresponding entries of ``y``. In general, the integration
algorithm is allowed to vary the number of evaluation points between subsequent calls to `f`

```julia
BatchIntegralFunction{iip,specialize}(f, output_prototype, [integral_prototype];
max_batch=typemax(Int))
```

Note that `f` is required and a `resize`-able buffer `output_prototype` to store the output,
or range of `f`, and in the case of inplace integrands a mutable container
`integral_prototype` to store the result of the integral. These buffers can be reused across
multiple compatible integrals to reduce allocations.

The keyword `max_batch` is used to set a soft limit on the number of points to batch at the
same time so that memory usage is controlled.

If `integral_prototype` is present, `f` is interpreted as in-place, and otherwise `f` is
assumed to be out-of-place.

## iip: In-Place vs Out-Of-Place

Out-of-place and in-place functions are both of the form ``f(y, u, p)``, but differ in the
element type of ``y``. Since `f` is allowed to return any type (e.g. real or complex numbers
or arrays), in-place functions must provide a container `integral_prototype` that is of the
right type for the final result of the integral, and the result is written to this container
in-place. When `f` is in-place, the buffer `output_prototype` is assumed to have a mutable
element type, and the last dimension of `output_prototype` should correspond to the batch
index. For example, `output_prototype` would have to be an `ElasticArray` or a
`VectorOfSimilarArrays` of an `ElasticArray`. When in-place forms are used, in-place array
operations may be used by algorithms to reduce allocations. If `integral_prototype` is not
provided, `f` is assumed to be out-of-place and quadrature is performed assuming
`output_prototype` is an `AbstractVector` with an immutable element type.

## specialize

This field is currently unused

## Fields

The fields of the BatchIntegralFunction type directly match the names of the inputs.
"""
struct BatchIntegralFunction{iip, specialize, F, Y, T} <:
AbstractIntegralFunction{iip}
f::F
output_prototype::Y
integral_prototype::T
max_batch::Int
end

TruncatedStacktraces.@truncate_stacktrace BatchIntegralFunction 1 2

######### Backwards Compatibility Overloads

(f::ODEFunction)(args...) = f.f(args...)
Expand Down Expand Up @@ -3955,6 +4085,61 @@ function BVPFunction(f, bc; kwargs...)
end
BVPFunction(f::BVPFunction; kwargs...) = f

function IntegralFunction{iip, specialize}(f, integral_prototype) where {iip, specialize}
IntegralFunction{iip, specialize, typeof(f), typeof(integral_prototype)}(f,
integral_prototype)
end

function IntegralFunction{iip}(f, integral_prototype) where {iip}
return IntegralFunction{iip, FullSpecialize}(f, integral_prototype)
end
function IntegralFunction(f)
calculated_iip = isinplace(f, 3, "integral", true)
if calculated_iip
throw(IntegrandMismatchFunctionError(calculated_iip, false))
end
IntegralFunction{false}(f, nothing)
end
function IntegralFunction(f, integral_prototype)
calcuated_iip = isinplace(f, 3, "integral", true)
if !calcuated_iip
throw(IntegrandMismatchFunctionError(calculated_iip, true))
end
IntegralFunction{true}(f, integral_prototype)
end

function BatchIntegralFunction{iip, specialize}(f, output_prototype, integral_prototype;
max_batch::Integer = typemax(Int)) where {iip, specialize}
BatchIntegralFunction{
iip,
specialize,
typeof(f),
typeof(output_prototype),
typeof(integral_prototype),
}(f,
output_prototype,
integral_prototype,
max_batch)
end

function BatchIntegralFunction{iip}(f,
output_prototype,
integral_prototype;
kwargs...) where {iip}
return BatchIntegralFunction{iip, FullSpecialize}(f,
output_prototype,
integral_prototype;
kwargs...)
end
function BatchIntegralFunction(f, output_prototype; kwargs...)
calcuated_iip = isinplace(f, 3, "batchintegral", true; has_two_dispatches = false)
BatchIntegralFunction{false}(f, output_prototype, nothing; kwargs...)
end
function BatchIntegralFunction(f, output_prototype, integral_prototype; kwargs...)
calcuated_iip = isinplace(f, 3, "batchintegral", true; has_two_dispatches = false)
BatchIntegralFunction{true}(f, output_prototype, integral_prototype; kwargs...)
end

########## Existence Functions

# Check that field/property exists (may be nothing)
Expand Down Expand Up @@ -4064,7 +4249,9 @@ for S in [:ODEFunction
:NonlinearFunction
:IntervalNonlinearFunction
:IncrementingODEFunction
:BVPFunction]
:BVPFunction
:IntegralFunction
:BatchIntegralFunction]
@eval begin
function ConstructionBase.constructorof(::Type{<:$S{iip}}) where {
iip,
Expand Down
32 changes: 31 additions & 1 deletion test/function_building_error_messages.jl
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ IntegralProblem(intf, [0.0], [1.0], p)
x = [1.0, 2.0]
y = rand(2, 2)
SampledIntegralProblem(y, x)
SampledIntegralProblem(y, x; dim=2)
SampledIntegralProblem(y, x; dim = 2)

# Optimization

Expand Down Expand Up @@ -601,3 +601,33 @@ BVPFunction(bfoop, bciip, vjp = bvjp)
bvjp(du, u, v, p, t) = [1.0]
BVPFunction(bfiip, bciip, vjp = bvjp)
BVPFunction(bfoop, bciip, vjp = bvjp)

# IntegralFunction

ioop(u, p) = p * u
iiip(y, u, p) = y .= u * p
i1(u) = u
itoo(y, u, p, a) = y .= u * p

IntegralFunction(ioop)
IntegralFunction(iiip, Float64[])

@test_throws SciMLBase.IntegrandMismatchFunctionError IntegralFunction(ioop, Float64[])
@test_throws SciMLBase.IntegrandMismatchFunctionError IntegralFunction(i1)
@test_throws SciMLBase.TooFewArgumentsError IntegralFunction(i1)
@test_throws SciMLBase.TooManyArgumentsError IntegralFunction(itoo)
@test_throws SciMLBase.TooManyArgumentsError IntegralFunction(itoo, Float64[])

# BatchIntegralFunction

boop(y, u, p) = y .= p .* u
biip(y, u, p) = y .= p .* u # this example is not realistic
bi1(y, u) = y .= p .* u
bitoo(y, u, p, a) = y .= p .* u

BatchIntegralFunction(boop, Float64[])
BatchIntegralFunction(boop, Float64[], max_batch = 20)
BatchIntegralFunction(biip, Float64[], Float64[]) # the 2nd argument should be an ElasticArray
@test_throws SciMLBase.TooFewArgumentsError IntegralFunction(bi1)
@test_throws SciMLBase.TooManyArgumentsError IntegralFunction(bitoo)
@test_throws SciMLBase.TooManyArgumentsError IntegralFunction(bitoo, Float64[])