diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml deleted file mode 100644 index 3494a9f1..00000000 --- a/.JuliaFormatter.toml +++ /dev/null @@ -1,3 +0,0 @@ -style = "sciml" -format_markdown = true -format_docstrings = true diff --git a/.github/workflows/FormatCheck.yml b/.github/workflows/FormatCheck.yml new file mode 100644 index 00000000..6762c6f3 --- /dev/null +++ b/.github/workflows/FormatCheck.yml @@ -0,0 +1,19 @@ +name: format-check + +on: + push: + branches: + - 'master' + - 'main' + - 'release-' + tags: '*' + pull_request: + +jobs: + runic: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: fredrikekre/runic-action@v1 + with: + version: '1' diff --git a/docs/make.jl b/docs/make.jl index 521ea93f..20f7e706 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -17,10 +17,10 @@ makedocs(; "Examples" => [ "examples/DiffEqFlux.md", "examples/adaptive_control.md", + "examples/ODE_jac.md", "examples/coulomb_control.md", - "examples/ODE_jac.md" ], - "API" => "api.md" + "API" => "api.md", ], repo = GitHub("SciML/ComponentArrays.jl"), sitename = "ComponentArrays.jl", diff --git a/examples/DiffEqFlux_example.jl b/examples/DiffEqFlux_example.jl index 0c481879..6ccd556d 100644 --- a/examples/DiffEqFlux_example.jl +++ b/examples/DiffEqFlux_example.jl @@ -23,12 +23,12 @@ tspan2 = (0.0f0, 25.0f0) # Make truth data function trueODEfunc(du, u, p, t) true_A = [-0.1 2.0; -2.0 -0.1] - du .= ((u .^ 3)'true_A)' + return du .= ((u .^ 3)'true_A)' end t = range(tspan[1], tspan[2], length = datasize) # t = Float32.(vcat(range(0.0, 0.9, length=10), 10 .^ range(log10(tspan[1]+1), log10(tspan[2]), length=datasize-10))) -t = Float32.([0; 10 .^ range(log10(tspan[1] + 0.01), log10(tspan[2]), length = datasize-1)]) +t = Float32.([0; 10 .^ range(log10(tspan[1] + 0.01), log10(tspan[2]), length = datasize - 1)]) prob = ODEProblem(trueODEfunc, u0, tspan2) ode_sol = solve(prob, Tsit5()) ode_data = Array(ode_sol(t)) + MeasurementNoise(0.1) @@ -39,7 +39,7 @@ neural_layer(in, out) = ComponentArray{Float32}(W = glorot_uniform(out, in), b = # Dense neural layer function dense(layer, activation = identity) = u -> activation.(layer.W * u + layer.b) -# Neural ODE function +# Neural ODE function dudt(u, p, t) = u .^ 3 |> dense(p.L1, σ) |> dense(p.L2) prob = ODEProblem(dudt, u0, tspan2) @@ -56,7 +56,7 @@ full_sol(θ) = solve(prob, Tsit5(), u0 = θ.u, p = θ.p) function loss_n_ode(θ) pred = predict_n_ode(θ) - loss = sum(abs2, ode_data .- pred)/datasize + 0.1*(sum(abs, θ.p)/length(θ.p)) + loss = sum(abs2, ode_data .- pred) / datasize + 0.1 * (sum(abs, θ.p) / length(θ.p)) return loss, pred end loss_n_ode(θ) @@ -83,12 +83,18 @@ cb = function (θ, loss, pred; doplot = false) plot!(pl_3, ode_sol, vars = (1, 2), label = "truth") scatter!(pl_3, pred[1, :], pred[2, :], label = "predicted data") scatter!(pl_3, ode_data[1, :], ode_data[2, :], label = "measured data") - plot!(pl_3, hcat(pred[1, :], ode_data[1, :])', hcat(pred[2, :], ode_data[2, :])', - label = false, color = :lightgray, legend = :bottomright) - - display(plot( - plot(pl_1, pl_2, layout = (2, 1), size = (400, 500)), pl_3, layout = (1, 2), size = ( - 950, 500))) + plot!( + pl_3, hcat(pred[1, :], ode_data[1, :])', hcat(pred[2, :], ode_data[2, :])', + label = false, color = :lightgray, legend = :bottomright + ) + + display( + plot( + plot(pl_1, pl_2, layout = (2, 1), size = (400, 500)), pl_3, layout = (1, 2), size = ( + 950, 500, + ) + ) + ) # frame(anim) return false end diff --git a/examples/ODE_example.jl b/examples/ODE_example.jl index 2d1f25c3..e54e0826 100644 --- a/examples/ODE_example.jl +++ b/examples/ODE_example.jl @@ -9,13 +9,13 @@ function lorenz!(D, u, p, t; f = 0.0) @unpack σ, ρ, β = p @unpack x, y, z = u - D.x = σ*(y - x) - D.y = x*(ρ - z) - y - f - D.z = x*y - β*z + D.x = σ * (y - x) + D.y = x * (ρ - z) - y - f + D.z = x * y - β * z return nothing end -lorenz_p = (σ = 10.0, ρ = 28.0, β = 8/3) +lorenz_p = (σ = 10.0, ρ = 28.0, β = 8 / 3) lorenz_ic = ComponentArray(x = 0.0, y = 0.0, z = 0.0) lorenz_prob = ODEProblem(lorenz!, lorenz_ic, tspan, lorenz_p) @@ -24,12 +24,12 @@ function lotka!(D, u, p, t; f = 0.0) @unpack α, β, γ, δ = p @unpack x, y = u - D.x = α*x - β*x*y + f - D.y = -γ*y + δ*x*y + D.x = α * x - β * x * y + f + D.y = -γ * y + δ * x * y return nothing end -lotka_p = (α = 2/3, β = 4/3, γ = 1.0, δ = 1.0) +lotka_p = (α = 2 / 3, β = 4 / 3, γ = 1.0, δ = 1.0) lotka_ic = ComponentArray(x = 1.0, y = 1.0) lotka_prob = ODEProblem(lotka!, lotka_ic, tspan, lotka_p) @@ -38,8 +38,8 @@ function composed!(D, u, p, t) c = p.c #coupling parameter @unpack lorenz, lotka = u - lorenz!(D.lorenz, lorenz, p.lorenz, t, f = c*lotka.x) - lotka!(D.lotka, lotka, p.lotka, t, f = c*lorenz.x) + lorenz!(D.lorenz, lorenz, p.lorenz, t, f = c * lotka.x) + lotka!(D.lotka, lotka, p.lotka, t, f = c * lorenz.x) return nothing end diff --git a/examples/ODE_jac_example.jl b/examples/ODE_jac_example.jl index d9065e5a..fc914b73 100644 --- a/examples/ODE_jac_example.jl +++ b/examples/ODE_jac_example.jl @@ -9,9 +9,9 @@ function lorenz!(D, u, p, t; f = 0.0) @unpack σ, ρ, β = p @unpack x, y, z = u - D.x = σ*(y - x) - D.y = x*(ρ - z) - y - f - D.z = x*y - β*z + D.x = σ * (y - x) + D.y = x * (ρ - z) - y - f + D.z = x * y - β * z return nothing end function lorenz_jac!(D, u, p, t) @@ -31,7 +31,7 @@ function lorenz_jac!(D, u, p, t) return nothing end -lorenz_p = (σ = 10.0, ρ = 28.0, β = 8/3) +lorenz_p = (σ = 10.0, ρ = 28.0, β = 8 / 3) lorenz_ic = ComponentArray(x = 0.0, y = 0.0, z = 0.0) lorenz_fun = ODEFunction(lorenz!, jac = lorenz_jac!) lorenz_prob = ODEProblem(lorenz_fun, lorenz_ic, tspan, lorenz_p) @@ -41,23 +41,23 @@ function lotka!(D, u, p, t; f = 0.0) @unpack α, β, γ, δ = p @unpack x, y = u - D.x = α*x - β*x*y + f - D.y = -γ*y + δ*x*y + D.x = α * x - β * x * y + f + D.y = -γ * y + δ * x * y return nothing end function lotka_jac!(D, u, p, t) @unpack α, β, γ, δ = p @unpack x, y = u - D[:x, :x] = α - β*y - D[:x, :y] = -β*x + D[:x, :x] = α - β * y + D[:x, :y] = -β * x - D[:y, :x] = δ*y - D[:y, :y] = -γ + δ*x + D[:y, :x] = δ * y + D[:y, :y] = -γ + δ * x return nothing end -lotka_p = (α = 2/3, β = 4/3, γ = 1.0, δ = 1.0) +lotka_p = (α = 2 / 3, β = 4 / 3, γ = 1.0, δ = 1.0) lotka_ic = ComponentArray(x = 1.0, y = 1.0) lotka_fun = ODEFunction(lotka!, jac = lotka_jac!) lotka_prob = ODEProblem(lotka_fun, lotka_ic, tspan, lotka_p) @@ -67,8 +67,8 @@ function composed!(D, u, p, t) c = p.c #coupling parameter @unpack lorenz, lotka = u - lorenz!(D.lorenz, lorenz, p.lorenz, t, f = c*lotka.x) - lotka!(D.lotka, lotka, p.lotka, t, f = c*lorenz.x) + lorenz!(D.lorenz, lorenz, p.lorenz, t, f = c * lotka.x) + lotka!(D.lotka, lotka, p.lotka, t, f = c * lorenz.x) return nothing end function composed_jac!(D, u, p, t) diff --git a/examples/adaptive_control_example.jl b/examples/adaptive_control_example.jl index a2b2933a..4132e3b4 100644 --- a/examples/adaptive_control_example.jl +++ b/examples/adaptive_control_example.jl @@ -12,8 +12,9 @@ maybe_apply(f, x, p, t) = f function apply_inputs(func; kwargs...) simfun(dx, x, p, t) = func( - dx, x, p, t; map(f->maybe_apply(f, x, p, t), (; kwargs...))...) - simfun(x, p, t) = func(x, p, t; map(f->maybe_apply(f, x, p, t), (; kwargs...))...) + dx, x, p, t; map(f -> maybe_apply(f, x, p, t), (; kwargs...))... + ) + simfun(x, p, t) = func(x, p, t; map(f -> maybe_apply(f, x, p, t), (; kwargs...))...) return simfun end @@ -27,7 +28,7 @@ SISO_simulator(P::TransferFunction) = SISO_simulator(ss(P)) function SISO_simulator(P::AbstractStateSpace) @unpack A, B, C, D = P - if size(D)!=(1, 1) + if size(D) != (1, 1) error("This is not a SISO system") end @@ -37,8 +38,8 @@ function SISO_simulator(P::AbstractStateSpace) DD = D[1, 1] return function sim!(dx, x, p, t; u = 0.0) - dx .= A*x + BB*u - return CC*x + DD*u + dx .= A * x + BB * u + return CC * x + DD * u end end @@ -60,13 +61,13 @@ nominal_sim! = SISO_simulator(nominal_plant) # To test robustness to uncertainty, we'll also include unmodeled dynamics with an entirely # different structure than our nominal plant model. -unmodeled_dynamics = 229/(s^2 + 30s + 229) +unmodeled_dynamics = 229 / (s^2 + 30s + 229) truth_plant = nominal_plant * unmodeled_dynamics truth_sim! = SISO_simulator(truth_plant) # We'll make a first-order sensor as well so we can add noise to our measurement τ = 0.005 -sensor_plant = 1 / (τ*s + 1) +sensor_plant = 1 / (τ * s + 1) sensor_sim! = SISO_simulator(sensor_plant) ## Derivative functions @@ -76,7 +77,7 @@ control(θ, w) = θ'w # We'll use a simple gradient descent adaptation law function adapt!(Dθ, θ, γ, t; e, w) - Dθ .= -γ*e*w + Dθ .= -γ * e * w return nothing end @@ -98,7 +99,7 @@ function feedback_sys!(D, vars, p, t; ym, r, n) return yp end # Now the full system takes in an input signal `r`, feeds it through the reference model, -# and feeds the output of the reference model `ym` and the input signal to `feedback_sys`. +# and feeds the output of the reference model `ym` and the input signal to `feedback_sys`. function system!(D, vars, p, t; r = 0.0, n = 0.0) @unpack reference_model, feedback_loop = vars @@ -122,16 +123,18 @@ sensor_ic = zeros(1) θ_est_ic = ComponentArray(θr = 0.0, θy = 0.0) ## Set up and run Simulation -function simulate(plant_fun, plant_ic; +function simulate( + plant_fun, plant_ic; tspan = tspan, input_signal = input_signal, adapt_gain = 1.5, noise_param = nothing, - deterministic_noise = 0.0) + deterministic_noise = 0.0 + ) noise(D, vars, p, t) = (D.feedback_loop.sensor[1] = noise_param) # Truth control parameters - θ_truth = (r = bm/bp, y = (ap-am)/bp) + θ_truth = (r = bm / bp, y = (ap - am) / bp) # Initial conditions ic = ComponentArray( @@ -139,14 +142,14 @@ function simulate(plant_fun, plant_ic; feedback_loop = ( parameter_estimates = θ_est_ic, sensor = sensor_ic, - plant_model = plant_ic + plant_model = plant_ic, ) ) # Model parameters p = ( gamma = adapt_gain, - plant_fun = plant_fun + plant_fun = plant_fun, ) sim_fun = apply_inputs(system!; r = input_signal, n = deterministic_noise) @@ -172,9 +175,14 @@ function simulate(plant_fun, plant_ic; ) # Parameter estimate tracking - bottom = plot(sol, - vars = Symbol.([ - "feedback_loop.parameter_estimates.θr", "feedback_loop.parameter_estimates.θy"])) + bottom = plot( + sol, + vars = Symbol.( + [ + "feedback_loop.parameter_estimates.θr", "feedback_loop.parameter_estimates.θy", + ] + ) + ) plot!( bottom, [tspan...], [θ_truth.r θ_truth.y; θ_truth.r θ_truth.y], @@ -184,8 +192,10 @@ function simulate(plant_fun, plant_ic; ) # Combine both plots - plot(top, bottom, layout = (2, 1), size = (800, 800)) + return plot(top, bottom, layout = (2, 1), size = (800, 800)) end -simulate(truth_sim!, truth_ic; input_signal = 2.0, - deterministic_noise = (x, p, t)->0.5sin(16.1t), noise_param = nothing) +simulate( + truth_sim!, truth_ic; input_signal = 2.0, + deterministic_noise = (x, p, t) -> 0.5sin(16.1t), noise_param = nothing +) diff --git a/ext/ComponentArraysGPUArraysExt.jl b/ext/ComponentArraysGPUArraysExt.jl index c222e668..e8ac1d4a 100644 --- a/ext/ComponentArraysGPUArraysExt.jl +++ b/ext/ComponentArraysGPUArraysExt.jl @@ -4,11 +4,13 @@ using ComponentArrays, LinearAlgebra, GPUArrays using ComponentArrays: recursive_eltype const GPUComponentArray = ComponentArray{ - T, N, <:GPUArrays.AbstractGPUArray, Ax} where {T, N, Ax} + T, N, <:GPUArrays.AbstractGPUArray, Ax, +} where {T, N, Ax} const GPUComponentVector{T, Ax} = ComponentArray{T, 1, <:GPUArrays.AbstractGPUVector, Ax} const GPUComponentMatrix{T, Ax} = ComponentArray{T, 2, <:GPUArrays.AbstractGPUMatrix, Ax} const GPUComponentVecorMat{ - T, Ax} = Union{GPUComponentVector{T, Ax}, GPUComponentMatrix{T, Ax}} + T, Ax, +} = Union{GPUComponentVector{T, Ax}, GPUComponentMatrix{T, Ax}} @static if pkgversion(GPUArrays) < v"11" GPUArrays.backend(x::ComponentArray) = GPUArrays.backend(getdata(x)) @@ -41,8 +43,13 @@ function Base.map(f, x::GPUComponentArray, args...) data = map(f, getdata(x), getdata.(args)...) return ComponentArray(data, getaxes(x)) end -function Base.map(f, x::GPUComponentArray, args::Vararg{Union{ - Base.AbstractBroadcasted, AbstractArray}}) +function Base.map( + f, x::GPUComponentArray, args::Vararg{ + Union{ + Base.AbstractBroadcasted, AbstractArray, + }, + } + ) data = map(f, getdata(x), map(getdata, args)...) return ComponentArray(data, getaxes(x)) end @@ -54,8 +61,10 @@ end function Base.mapreduce(f, op, x::GPUComponentArray, args...; kwargs...) return mapreduce(f, op, getdata(x), map(getdata, args)...; kwargs...) end -function Base.mapreduce(f, op, x::GPUComponentArray, - args::Vararg{Union{Base.AbstractBroadcasted, AbstractArray}}; kwargs...) +function Base.mapreduce( + f, op, x::GPUComponentArray, + args::Vararg{Union{Base.AbstractBroadcasted, AbstractArray}}; kwargs... + ) return mapreduce(f, op, getdata(x), map(getdata, args)...; kwargs...) end @@ -67,224 +76,311 @@ Base.any(f::Function, A::GPUComponentArray) = mapreduce(f, |, getdata(A)) Base.all(f::Function, A::GPUComponentArray) = mapreduce(f, &, getdata(A)) function Base.count(pred::Function, A::GPUComponentArray; dims = :, init = 0) - mapreduce(pred, Base.add_sum, getdata(A); init = init, dims = dims) + return mapreduce(pred, Base.add_sum, getdata(A); init = init, dims = dims) end # avoid calling into `initarray!` -for (fname, op) in [(:sum, :(Base.add_sum)), (:prod, :(Base.mul_prod)), - (:maximum, :(Base.max)), (:minimum, :(Base.min)), - (:all, :&), (:any, :|)] +for (fname, op) in [ + (:sum, :(Base.add_sum)), (:prod, :(Base.mul_prod)), + (:maximum, :(Base.max)), (:minimum, :(Base.min)), + (:all, :&), (:any, :|), + ] fname! = Symbol(fname, '!') @eval begin - Base.$(fname!)(f::Function, + Base.$(fname!)( + f::Function, r::GPUComponentArray, - A::GPUComponentArray{T}) where {T} = GPUArrays.mapreducedim!( - f, $(op), getdata(r), getdata(A); init = neutral_element($(op), T)) + A::GPUComponentArray{T} + ) where {T} = GPUArrays.mapreducedim!( + f, $(op), getdata(r), getdata(A); init = neutral_element($(op), T) + ) end end -function ComponentArrays.ComponentArray(nt::NamedTuple{names, - <:Tuple{Vararg{Union{GPUArrays.AbstractGPUArray, GPUComponentArray}}}}) where {names} +function ComponentArrays.ComponentArray( + nt::NamedTuple{ + names, + <:Tuple{Vararg{Union{GPUArrays.AbstractGPUArray, GPUComponentArray}}}, + } + ) where {names} T = recursive_eltype(nt) gpuarray = getdata(first(nt)) G = Base.typename(typeof(gpuarray)).wrapper # SciMLBase.parameterless_type(gpuarray) return GPUArrays.adapt(G, ComponentArray(NamedTuple{names}(map(GPUArrays.adapt(Array{T}), nt)))) end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, +function LinearAlgebra.mul!( + C::GPUComponentVecorMat, A::GPUComponentVecorMat, - B::GPUComponentVecorMat, a::Number, b::Number) + B::GPUComponentVecorMat, a::Number, b::Number + ) return GPUArrays.generic_matmatmul!(C, A, B, a, b) end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, +function LinearAlgebra.mul!( + C::GPUComponentVecorMat, A::GPUComponentVecorMat, B::LinearAlgebra.Adjoint{<:Any, <:GPUArrays.AbstractGPUVecOrMat}, - a::Number, b::Number) + a::Number, b::Number + ) return GPUArrays.generic_matmatmul!(C, A, B, a, b) end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, +function LinearAlgebra.mul!( + C::GPUComponentVecorMat, A::GPUComponentVecorMat, B::LinearAlgebra.Adjoint{<:Any, <:GPUComponentVecorMat}, - a::Number, b::Number) + a::Number, b::Number + ) return GPUArrays.generic_matmatmul!(C, A, B, a, b) end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, +function LinearAlgebra.mul!( + C::GPUComponentVecorMat, A::GPUComponentVecorMat, B::LinearAlgebra.Transpose{<:Any, <:GPUArrays.AbstractGPUVecOrMat}, - a::Number, b::Number) + a::Number, b::Number + ) return GPUArrays.generic_matmatmul!(C, A, B, a, b) end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, +function LinearAlgebra.mul!( + C::GPUComponentVecorMat, A::GPUComponentVecorMat, - B::LinearAlgebra.Transpose{<:Any, <:GPUComponentVecorMat - }, a::Number, b::Number) + B::LinearAlgebra.Transpose{ + <:Any, <:GPUComponentVecorMat, + }, a::Number, b::Number + ) return GPUArrays.generic_matmatmul!(C, A, B, a, b) end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, +function LinearAlgebra.mul!( + C::GPUComponentVecorMat, A::LinearAlgebra.Adjoint{<:Any, <:GPUArrays.AbstractGPUVecOrMat}, - B::GPUComponentVecorMat, a::Number, b::Number) + B::GPUComponentVecorMat, a::Number, b::Number + ) return GPUArrays.generic_matmatmul!(C, A, B, a, b) end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, +function LinearAlgebra.mul!( + C::GPUComponentVecorMat, A::LinearAlgebra.Adjoint{<:Any, <:GPUComponentVecorMat}, - B::GPUComponentVecorMat, a::Number, b::Number) + B::GPUComponentVecorMat, a::Number, b::Number + ) return GPUArrays.generic_matmatmul!(C, A, B, a, b) end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, +function LinearAlgebra.mul!( + C::GPUComponentVecorMat, A::LinearAlgebra.Transpose{<:Any, <:GPUArrays.AbstractGPUVecOrMat}, - B::GPUComponentVecorMat, a::Number, b::Number) + B::GPUComponentVecorMat, a::Number, b::Number + ) return GPUArrays.generic_matmatmul!(C, A, B, a, b) end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, - A::LinearAlgebra.Transpose{<:Any, <:GPUComponentVecorMat +function LinearAlgebra.mul!( + C::GPUComponentVecorMat, + A::LinearAlgebra.Transpose{ + <:Any, <:GPUComponentVecorMat, }, B::GPUComponentVecorMat, - a::Number, b::Number) + a::Number, b::Number + ) return GPUArrays.generic_matmatmul!(C, A, B, a, b) end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, +function LinearAlgebra.mul!( + C::GPUComponentVecorMat, A::LinearAlgebra.Transpose{<:Any, <:GPUArrays.AbstractGPUVecOrMat}, B::LinearAlgebra.Adjoint{<:Any, <:GPUArrays.AbstractGPUVecOrMat}, - a::Number, b::Number) + a::Number, b::Number + ) return GPUArrays.generic_matmatmul!(C, A, B, a, b) end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, - A::LinearAlgebra.Transpose{<:Any, <:GPUComponentVecorMat +function LinearAlgebra.mul!( + C::GPUComponentVecorMat, + A::LinearAlgebra.Transpose{ + <:Any, <:GPUComponentVecorMat, }, B::LinearAlgebra.Adjoint{<:Any, <:GPUComponentVecorMat}, - a::Number, b::Number) + a::Number, b::Number + ) return GPUArrays.generic_matmatmul!(C, A, B, a, b) end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, +function LinearAlgebra.mul!( + C::GPUComponentVecorMat, A::LinearAlgebra.Adjoint{<:Any, <:GPUArrays.AbstractGPUVecOrMat}, B::LinearAlgebra.Transpose{<:Any, <:GPUArrays.AbstractGPUVecOrMat}, - a::Number, b::Number) + a::Number, b::Number + ) return GPUArrays.generic_matmatmul!(C, A, B, a, b) end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, +function LinearAlgebra.mul!( + C::GPUComponentVecorMat, A::LinearAlgebra.Adjoint{<:Any, <:GPUComponentVecorMat}, - B::LinearAlgebra.Transpose{<:Any, <:GPUComponentVecorMat - }, a::Number, b::Number) + B::LinearAlgebra.Transpose{ + <:Any, <:GPUComponentVecorMat, + }, a::Number, b::Number + ) return GPUArrays.generic_matmatmul!(C, A, B, a, b) end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, +function LinearAlgebra.mul!( + C::GPUComponentVecorMat, A::LinearAlgebra.Adjoint{<:Any, <:GPUArrays.AbstractGPUVecOrMat}, B::LinearAlgebra.Adjoint{<:Any, <:GPUArrays.AbstractGPUVecOrMat}, - a::Number, b::Number) + a::Number, b::Number + ) return GPUArrays.generic_matmatmul!(C, A, B, a, b) end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, +function LinearAlgebra.mul!( + C::GPUComponentVecorMat, A::LinearAlgebra.Adjoint{<:Any, <:GPUComponentVecorMat}, B::LinearAlgebra.Adjoint{<:Any, <:GPUComponentVecorMat}, - a::Number, b::Number) + a::Number, b::Number + ) return GPUArrays.generic_matmatmul!(C, A, B, a, b) end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, +function LinearAlgebra.mul!( + C::GPUComponentVecorMat, A::LinearAlgebra.Transpose{<:Any, <:GPUArrays.AbstractGPUVecOrMat}, B::LinearAlgebra.Transpose{<:Any, <:GPUArrays.AbstractGPUVecOrMat}, - a::Number, b::Number) + a::Number, b::Number + ) return GPUArrays.generic_matmatmul!(C, A, B, a, b) end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, - A::LinearAlgebra.Transpose{<:Any, <:GPUComponentVecorMat +function LinearAlgebra.mul!( + C::GPUComponentVecorMat, + A::LinearAlgebra.Transpose{ + <:Any, <:GPUComponentVecorMat, }, - B::LinearAlgebra.Transpose{<:Any, <:GPUComponentVecorMat - }, a::Number, b::Number) + B::LinearAlgebra.Transpose{ + <:Any, <:GPUComponentVecorMat, + }, a::Number, b::Number + ) return GPUArrays.generic_matmatmul!(C, A, B, a, b) end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, +function LinearAlgebra.mul!( + C::GPUComponentVecorMat, A::GPUComponentVecorMat, - B::GPUComponentVecorMat, a::Real, b::Real) + B::GPUComponentVecorMat, a::Real, b::Real + ) return GPUArrays.generic_matmatmul!(C, A, B, a, b) end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, +function LinearAlgebra.mul!( + C::GPUComponentVecorMat, A::GPUComponentVecorMat, B::LinearAlgebra.Adjoint{<:Any, <:GPUArrays.AbstractGPUVecOrMat}, a::Real, - b::Real) + b::Real + ) return GPUArrays.generic_matmatmul!(C, A, B, a, b) end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, +function LinearAlgebra.mul!( + C::GPUComponentVecorMat, A::GPUComponentVecorMat, B::LinearAlgebra.Adjoint{<:Any, <:GPUComponentVecorMat}, - a::Real, b::Real) + a::Real, b::Real + ) return GPUArrays.generic_matmatmul!(C, A, B, a, b) end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, +function LinearAlgebra.mul!( + C::GPUComponentVecorMat, A::GPUComponentVecorMat, - B::LinearAlgebra.Transpose{<:Any, <:GPUComponentVecorMat - }, a::Real, b::Real) + B::LinearAlgebra.Transpose{ + <:Any, <:GPUComponentVecorMat, + }, a::Real, b::Real + ) return GPUArrays.generic_matmatmul!(C, A, B, a, b) end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, +function LinearAlgebra.mul!( + C::GPUComponentVecorMat, A::LinearAlgebra.Adjoint{<:Any, <:GPUArrays.AbstractGPUVecOrMat}, - B::GPUComponentVecorMat, a::Real, b::Real) + B::GPUComponentVecorMat, a::Real, b::Real + ) return GPUArrays.generic_matmatmul!(C, A, B, a, b) end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, +function LinearAlgebra.mul!( + C::GPUComponentVecorMat, A::LinearAlgebra.Adjoint{<:Any, <:GPUComponentVecorMat}, - B::GPUComponentVecorMat, a::Real, b::Real) + B::GPUComponentVecorMat, a::Real, b::Real + ) return GPUArrays.generic_matmatmul!(C, A, B, a, b) end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, +function LinearAlgebra.mul!( + C::GPUComponentVecorMat, A::LinearAlgebra.Transpose{<:Any, <:GPUArrays.AbstractGPUVecOrMat}, - B::GPUComponentVecorMat, a::Real, b::Real) + B::GPUComponentVecorMat, a::Real, b::Real + ) return GPUArrays.generic_matmatmul!(C, A, B, a, b) end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, - A::LinearAlgebra.Transpose{<:Any, <:GPUComponentVecorMat +function LinearAlgebra.mul!( + C::GPUComponentVecorMat, + A::LinearAlgebra.Transpose{ + <:Any, <:GPUComponentVecorMat, }, B::GPUComponentVecorMat, - a::Real, b::Real) + a::Real, b::Real + ) return GPUArrays.generic_matmatmul!(C, A, B, a, b) end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, +function LinearAlgebra.mul!( + C::GPUComponentVecorMat, A::LinearAlgebra.Transpose{<:Any, <:GPUArrays.AbstractGPUVecOrMat}, B::LinearAlgebra.Adjoint{<:Any, <:GPUArrays.AbstractGPUVecOrMat}, a::Real, - b::Real) + b::Real + ) return GPUArrays.generic_matmatmul!(C, A, B, a, b) end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, - A::LinearAlgebra.Transpose{<:Any, <:GPUComponentVecorMat +function LinearAlgebra.mul!( + C::GPUComponentVecorMat, + A::LinearAlgebra.Transpose{ + <:Any, <:GPUComponentVecorMat, }, B::LinearAlgebra.Adjoint{<:Any, <:GPUComponentVecorMat}, - a::Real, b::Real) + a::Real, b::Real + ) return GPUArrays.generic_matmatmul!(C, A, B, a, b) end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, +function LinearAlgebra.mul!( + C::GPUComponentVecorMat, A::LinearAlgebra.Adjoint{<:Any, <:GPUArrays.AbstractGPUVecOrMat}, B::LinearAlgebra.Transpose{<:Any, <:GPUArrays.AbstractGPUVecOrMat}, - a::Real, b::Real) + a::Real, b::Real + ) return GPUArrays.generic_matmatmul!(C, A, B, a, b) end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, +function LinearAlgebra.mul!( + C::GPUComponentVecorMat, A::LinearAlgebra.Adjoint{<:Any, <:GPUComponentVecorMat}, - B::LinearAlgebra.Transpose{<:Any, <:GPUComponentVecorMat - }, a::Real, b::Real) + B::LinearAlgebra.Transpose{ + <:Any, <:GPUComponentVecorMat, + }, a::Real, b::Real + ) return GPUArrays.generic_matmatmul!(C, A, B, a, b) end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, +function LinearAlgebra.mul!( + C::GPUComponentVecorMat, A::LinearAlgebra.Adjoint{<:Any, <:GPUArrays.AbstractGPUVecOrMat}, B::LinearAlgebra.Adjoint{<:Any, <:GPUArrays.AbstractGPUVecOrMat}, a::Real, - b::Real) + b::Real + ) return GPUArrays.generic_matmatmul!(C, A, B, a, b) end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, +function LinearAlgebra.mul!( + C::GPUComponentVecorMat, A::LinearAlgebra.Adjoint{<:Any, <:GPUComponentVecorMat}, B::LinearAlgebra.Adjoint{<:Any, <:GPUComponentVecorMat}, - a::Real, b::Real) + a::Real, b::Real + ) return GPUArrays.generic_matmatmul!(C, A, B, a, b) end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, +function LinearAlgebra.mul!( + C::GPUComponentVecorMat, A::LinearAlgebra.Transpose{<:Any, <:GPUArrays.AbstractGPUVecOrMat}, B::LinearAlgebra.Transpose{<:Any, <:GPUArrays.AbstractGPUVecOrMat}, - a::Real, b::Real) + a::Real, b::Real + ) return GPUArrays.generic_matmatmul!(C, A, B, a, b) end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, - A::LinearAlgebra.Transpose{<:Any, <:GPUComponentVecorMat +function LinearAlgebra.mul!( + C::GPUComponentVecorMat, + A::LinearAlgebra.Transpose{ + <:Any, <:GPUComponentVecorMat, }, - B::LinearAlgebra.Transpose{<:Any, <:GPUComponentVecorMat - }, a::Real, b::Real) + B::LinearAlgebra.Transpose{ + <:Any, <:GPUComponentVecorMat, + }, a::Real, b::Real + ) return GPUArrays.generic_matmatmul!(C, A, B, a, b) end diff --git a/ext/ComponentArraysReactantExt.jl b/ext/ComponentArraysReactantExt.jl index 19c234b1..7269f737 100644 --- a/ext/ComponentArraysReactantExt.jl +++ b/ext/ComponentArraysReactantExt.jl @@ -4,7 +4,7 @@ using ArrayInterface: ArrayInterface using ComponentArrays, Reactant const TracedComponentVector{T} = ComponentVector{ - Reactant.TracedRNumber{T}, <:Reactant.TracedRArray{T} + Reactant.TracedRNumber{T}, <:Reactant.TracedRArray{T}, } where {T} # Reactant is good at memory management but not great at handling wrapped types. So we avoid diff --git a/ext/ComponentArraysReverseDiffExt.jl b/ext/ComponentArraysReverseDiffExt.jl index c9b8b485..a2c6e3ad 100644 --- a/ext/ComponentArraysReverseDiffExt.jl +++ b/ext/ComponentArraysReverseDiffExt.jl @@ -3,10 +3,11 @@ module ComponentArraysReverseDiffExt using ComponentArrays, ReverseDiff const TrackedComponentArray{ - V, D, N, DA, A, Ax} = ReverseDiff.TrackedArray{V, D, N, ComponentArray{V, N, A, Ax}, DA} + V, D, N, DA, A, Ax, +} = ReverseDiff.TrackedArray{V, D, N, ComponentArray{V, N, A, Ax}, DA} function maybe_tracked_array(val::AbstractArray, der, tape, inds, origin) - ReverseDiff.TrackedArray(val, der, tape) + return ReverseDiff.TrackedArray(val, der, tape) end function maybe_tracked_array(val::Real, der, tape, inds, origin::AbstractVector) ax = getaxes(ReverseDiff.value(origin))[1] @@ -34,8 +35,12 @@ function Base.getproperty(tca::TrackedComponentArray, s::Symbol) end end -function Base.propertynames(::TrackedComponentArray{V, D, N, DA, A, - Tuple{Ax}}) where {V, D, N, DA, A, Ax <: ComponentArrays.AbstractAxis} +function Base.propertynames( + ::TrackedComponentArray{ + V, D, N, DA, A, + Tuple{Ax}, + } + ) where {V, D, N, DA, A, Ax <: ComponentArrays.AbstractAxis} return propertynames(ComponentArrays.indexmap(Ax)) end @@ -47,6 +52,7 @@ end @inline ComponentArrays.__value(x::AbstractArray{<:ReverseDiff.TrackedReal}) = ReverseDiff.value.(x) @inline ComponentArrays.__value(x::ReverseDiff.TrackedArray) = ReverseDiff.value(x) @inline ComponentArrays.__value(x::TrackedComponentArray) = ComponentArray( - ComponentArrays.__value(getdata(x)), getaxes(x)) + ComponentArrays.__value(getdata(x)), getaxes(x) +) end diff --git a/ext/ComponentArraysSciMLBaseExt.jl b/ext/ComponentArraysSciMLBaseExt.jl index b0a28941..9c81cfac 100644 --- a/ext/ComponentArraysSciMLBaseExt.jl +++ b/ext/ComponentArraysSciMLBaseExt.jl @@ -3,8 +3,11 @@ module ComponentArraysSciMLBaseExt using ComponentArrays, SciMLBase -function SciMLBase.getsyms(sol::SciMLBase.AbstractODESolution{ - T, N, C}) where {T, N, C <: AbstractVector{<:ComponentArray}} +function SciMLBase.getsyms( + sol::SciMLBase.AbstractODESolution{ + T, N, C, + } + ) where {T, N, C <: AbstractVector{<:ComponentArray}} if SciMLBase.has_syms(sol.prob.f) return sol.prob.f.syms else diff --git a/ext/ComponentArraysTrackerExt.jl b/ext/ComponentArraysTrackerExt.jl index bf47ff9c..816d284c 100644 --- a/ext/ComponentArraysTrackerExt.jl +++ b/ext/ComponentArraysTrackerExt.jl @@ -13,8 +13,12 @@ Tracker.extract_grad!(ca::ComponentArray) = Tracker.extract_grad!(getdata(ca)) Tracker.data(ca::ComponentArray) = ComponentArray(Tracker.data(getdata(ca)), getaxes(ca)) -function Base.materialize(bc::Base.Broadcast.Broadcasted{Tracker.TrackedStyle, Nothing, - typeof(zero), <:Tuple{<:ComponentVector}}) +function Base.materialize( + bc::Base.Broadcast.Broadcasted{ + Tracker.TrackedStyle, Nothing, + typeof(zero), <:Tuple{<:ComponentVector}, + } + ) ca = first(bc.args) return ComponentArray(zero.(getdata(ca)), getaxes(ca)) end @@ -26,8 +30,10 @@ end # For TrackedArrays ignore Base.maybeview ## Tracker with views doesn't work quite well -@inline function Base.getproperty(x::ComponentVector{T, <:TrackedArray}, - s::Symbol) where {T} +@inline function Base.getproperty( + x::ComponentVector{T, <:TrackedArray}, + s::Symbol + ) where {T} return getproperty(x, Val(s)) end @@ -35,8 +41,10 @@ end return ComponentArrays._getindex(Base.getindex, x, v) end -function ArrayInterface.restructure(x::ComponentVector, - y::ComponentVector{T, <:TrackedArray}) where {T} +function ArrayInterface.restructure( + x::ComponentVector, + y::ComponentVector{T, <:TrackedArray} + ) where {T} getaxes(x) == getaxes(y) || error("Axes must match") return y end diff --git a/src/array_interface.jl b/src/array_interface.jl index 65940e01..7afc4d8f 100644 --- a/src/array_interface.jl +++ b/src/array_interface.jl @@ -2,7 +2,7 @@ Base.parent(x::ComponentArray) = getfield(x, :data) Base.size(x::ComponentArray) = size(getdata(x)) function StaticArrayInterface.static_size(A::ComponentArray) - StaticArrayInterface.static_size(parent(A)) + return StaticArrayInterface.static_size(parent(A)) end Base.elsize(x::Type{<:ComponentArray{T, N, A, Axes}}) where {T, N, A, Axes} = Base.elsize(A) @@ -11,16 +11,26 @@ Base.elsize(x::Type{<:ComponentArray{T, N, A, Axes}}) where {T, N, A, Axes} = Ba Base.axes(x::ComponentArray) = CombinedAxis.(getaxes(x), axes(getdata(x))) function Base.reinterpret(::Type{T}, x::ComponentArray, args...) where {T} - ComponentArray(reinterpret(T, getdata(x), args...), getaxes(x)) + return ComponentArray(reinterpret(T, getdata(x), args...), getaxes(x)) end -function ArrayInterface.indices_do_not_alias(::Type{ComponentArray{ - T, N, A, Axes}}) where {T, N, A, Axes} - ArrayInterface.indices_do_not_alias(A) +function ArrayInterface.indices_do_not_alias( + ::Type{ + ComponentArray{ + T, N, A, Axes, + }, + } + ) where {T, N, A, Axes} + return ArrayInterface.indices_do_not_alias(A) end -function ArrayInterface.instances_do_not_alias(::Type{ComponentArray{ - T, N, A, Axes}}) where {T, N, A, Axes} - ArrayInterface.instances_do_not_alias(A) +function ArrayInterface.instances_do_not_alias( + ::Type{ + ComponentArray{ + T, N, A, Axes, + }, + } + ) where {T, N, A, Axes} + return ArrayInterface.instances_do_not_alias(A) end # Cats @@ -28,7 +38,7 @@ end function Base.hcat(x::AbstractComponentVecOrMat, y::AbstractComponentVecOrMat) ax_x, ax_y = second_axis.((x, y)) if reduce((accum, key) -> accum || (key in keys(ax_x)), keys(ax_y); init = false) || - getaxes(x)[1] != getaxes(y)[1] + getaxes(x)[1] != getaxes(y)[1] return hcat(getdata(x), getdata(y)) else data_x, data_y = getdata.((x, y)) @@ -36,8 +46,12 @@ function Base.hcat(x::AbstractComponentVecOrMat, y::AbstractComponentVecOrMat) idxmap_x, idxmap_y = indexmap.((ax_x, ax_y)) axs = getaxes(x) return ComponentArray( - hcat(data_x, data_y), axs[1], Axis((; - idxmap_x..., idxmap_y...)), axs[3:end]...) + hcat(data_x, data_y), axs[1], Axis( + (; + idxmap_x..., idxmap_y..., + ) + ), axs[3:end]... + ) end end @@ -62,7 +76,7 @@ end function Base.vcat(x::AbstractComponentVecOrMat{<:Number}, y::AbstractComponentVecOrMat{<:Number}) ax_x, ax_y = getindex.(getaxes.((x, y)), 1) if reduce((accum, key) -> accum || (key in keys(ax_x)), keys(ax_y); init = false) || - getaxes(x)[2:end] != getaxes(y)[2:end] + getaxes(x)[2:end] != getaxes(y)[2:end] return vcat(getdata(x), getdata(y)) else data_x, data_y = getdata.((x, y)) @@ -72,16 +86,21 @@ function Base.vcat(x::AbstractComponentVecOrMat{<:Number}, y::AbstractComponentV end end function Base.vcat(x::CV...) where {CV <: AdjOrTransComponentArray{<:Number}} - ComponentArray(reduce(vcat, map(y->getdata(y.parent)', x)), getaxes(x[1])) + return ComponentArray(reduce(vcat, map(y -> getdata(y.parent)', x)), getaxes(x[1])) end Base.vcat(x::ComponentVector{<:Number}, args...) = vcat(getdata(x), getdata.(args)...) -function Base.vcat(x::ComponentVector{<:Number}, - args::Vararg{Union{Number, UniformScaling, AbstractVecOrMat{<:Number}}}) - vcat(getdata(x), getdata.(args)...) +function Base.vcat( + x::ComponentVector{<:Number}, + args::Vararg{Union{Number, UniformScaling, AbstractVecOrMat{<:Number}}} + ) + return vcat(getdata(x), getdata.(args)...) end -function Base.vcat(x::ComponentVector{<:Number}, args::Vararg{ - AbstractVector{T}, N}) where {T <: Number, N} - vcat(getdata(x), getdata.(args)...) +function Base.vcat( + x::ComponentVector{<:Number}, args::Vararg{ + AbstractVector{T}, N, + } + ) where {T <: Number, N} + return vcat(getdata(x), getdata.(args)...) end function Base.hvcat(row_lengths::NTuple{N, Int}, xs::Vararg{AbstractComponentVecOrMat}) where {N} @@ -98,20 +117,23 @@ end function Base.permutedims(x::ComponentArray, dims) axs = getaxes(x) - return ComponentArray(permutedims(getdata(x), dims), map(i->axs[i], dims)...) + return ComponentArray(permutedims(getdata(x), dims), map(i -> axs[i], dims)...) end ## Indexing function Base.IndexStyle(::Type{<:ComponentArray{T, N, <:A, <:Axes}}) where {T, N, A, Axes} - IndexStyle(A) + return IndexStyle(A) end # Since we aren't really using the standard approach to indexing, this will forward things to # the correct methods Base.to_indices(x::ComponentArray, i::Tuple{Any}) = i -function Base.to_indices(x::ComponentArray, i::NTuple{ - N, Union{Integer, CartesianIndex}}) where {N} - i +function Base.to_indices( + x::ComponentArray, i::NTuple{ + N, Union{Integer, CartesianIndex}, + } + ) where {N} + return i end Base.to_indices(x::ComponentArray, i::NTuple{N, Int64}) where {N} = i Base.to_index(x::ComponentArray, i) = i @@ -131,20 +153,24 @@ Base.@propagate_inbounds Base.getindex(x::ComponentArray, ::Colon, ::Vararg{Colo # Set ComponentArray index Base.@propagate_inbounds Base.setindex!( - x::ComponentArray, v, idx::FlatOrColonIdx...) = setindex!(getdata(x), v, idx...) + x::ComponentArray, v, idx::FlatOrColonIdx... +) = setindex!(getdata(x), v, idx...) Base.@propagate_inbounds Base.setindex!(x::ComponentArray, v, ::Colon) = setindex!(getdata(x), v, :) @inline Base.setindex!(x::ComponentArray, v, idx...) = setindex!(x, v, toval.(idx)...) @inline Base.setindex!(x::ComponentArray, v, idx::Vararg{Val}) = _setindex!(x, v, idx...) # Explicitly view Base.@propagate_inbounds Base.view( - x::ComponentArray, idx::Vararg{ComponentArrays.FlatIdx}) = view(getdata(x), idx...) + x::ComponentArray, idx::Vararg{ComponentArrays.FlatIdx} +) = view(getdata(x), idx...) Base.@propagate_inbounds Base.view(x::ComponentArray, idx...) = _getindex(view, x, toval.(idx)...) Base.@propagate_inbounds Base.maybeview( - x::ComponentArray, idx::Vararg{ComponentArrays.FlatIdx}) = Base.maybeview(getdata(x), idx...) + x::ComponentArray, idx::Vararg{ComponentArrays.FlatIdx} +) = Base.maybeview(getdata(x), idx...) Base.@propagate_inbounds Base.maybeview( - x::ComponentArray, idx...) = _getindex(Base.maybeview, x, toval.(idx)...) + x::ComponentArray, idx... +) = _getindex(Base.maybeview, x, toval.(idx)...) # Generated get and set index methods to do all of the heavy lifting in the type domain @generated function _getindex(index_fun, x::ComponentArray, idx...) @@ -167,18 +193,26 @@ end ## Linear Algebra function Base.pointer(x::ComponentArray{T, N, A, Axes}) where {T, N, A <: DenseArray, Axes} - pointer(getdata(x)) + return pointer(getdata(x)) end -function Base.unsafe_convert(::Type{Ptr{T}}, x::ComponentArray{ - T, N, A, Axes}) where {T, N, A, Axes} - Base.unsafe_convert(Ptr{T}, getdata(x)) +function Base.unsafe_convert( + ::Type{Ptr{T}}, x::ComponentArray{ + T, N, A, Axes, + } + ) where {T, N, A, Axes} + return Base.unsafe_convert(Ptr{T}, getdata(x)) end Base.strides(x::ComponentArray) = strides(getdata(x)) for f in [:device, :stride_rank, :contiguous_axis, :contiguous_batch_size, :dense_dims] - @eval StaticArrayInterface.$f(::Type{ComponentArray{ - T, N, A, Axes}}) where {T, N, A, Axes} = StaticArrayInterface.$f(A) + @eval StaticArrayInterface.$f( + ::Type{ + ComponentArray{ + T, N, A, Axes, + }, + } + ) where {T, N, A, Axes} = StaticArrayInterface.$f(A) end Base.stride(x::ComponentArray, k) = stride(getdata(x), k) diff --git a/src/axis.jl b/src/axis.jl index 2d95ac11..3dfa8e49 100644 --- a/src/axis.jl +++ b/src/axis.jl @@ -94,7 +94,7 @@ struct PartitionedAxis{PartSz, IdxMap, Ax <: AbstractAxis{IdxMap}} <: AbstractAx end end function PartitionedAxis{PartSz, IdxMap, Ax}() where {PartSz, IdxMap, Ax} - PartitionedAxis(PartSz, Ax()) + return PartitionedAxis(PartSz, Ax()) end PartitionedAxis(PartSz, IdxMap) = PartitionedAxis(PartSz, Axis(IdxMap)) @@ -122,18 +122,25 @@ ViewAxis(Inds) = Inds Base.length(ax::ViewAxis{Inds}) where {Inds} = length(Inds) # Fix https://github.com/Deltares/Ribasim/issues/2028 -function Base.getindex(::ViewAxis{Inds, IdxMap, <:ComponentArrays.Shaped1DAxis}, - idx::Integer) where {Inds, IdxMap} - Inds[idx] +function Base.getindex( + ::ViewAxis{Inds, IdxMap, <:ComponentArrays.Shaped1DAxis}, + idx::Integer + ) where {Inds, IdxMap} + return Inds[idx] end -function Base.iterate(::ViewAxis{ - Inds, IdxMap, <:ComponentArrays.Shaped1DAxis}) where {Inds, IdxMap} - iterate(Inds) +function Base.iterate( + ::ViewAxis{ + Inds, IdxMap, <:ComponentArrays.Shaped1DAxis, + } + ) where {Inds, IdxMap} + return iterate(Inds) end function Base.iterate( ::ViewAxis{ - Inds, IdxMap, <:ComponentArrays.Shaped1DAxis}, idx) where {Inds, IdxMap} - iterate(Inds, idx) + Inds, IdxMap, <:ComponentArrays.Shaped1DAxis, + }, idx + ) where {Inds, IdxMap} + return iterate(Inds, idx) end const View = ViewAxis @@ -154,9 +161,11 @@ Axis(x) = FlatAxis() const NotShapedAxis = Union{Axis{IdxMap}, FlatAxis, NullAxis, Shaped1DAxis} where {IdxMap} const NotPartitionedAxis = Union{ - Axis{IdxMap}, FlatAxis, NullAxis, ShapedAxis{Shape}, Shaped1DAxis} where {Shape, IdxMap} + Axis{IdxMap}, FlatAxis, NullAxis, ShapedAxis{Shape}, Shaped1DAxis, +} where {Shape, IdxMap} const NotShapedOrPartitionedAxis = Union{ - Axis{IdxMap}, FlatAxis, Shaped1DAxis} where {IdxMap} + Axis{IdxMap}, FlatAxis, Shaped1DAxis, +} where {IdxMap} Base.merge(axs::Vararg{Axis}) = Axis(merge(indexmap.(axs)...)) @@ -167,10 +176,12 @@ Base.keys(ax::AbstractAxis) = keys(indexmap(ax)) reindex(i, offset) = i .+ offset reindex(ax::FlatAxis, _) = ax -reindex(ax::Axis, offset) = Axis(map(x->reindex(x, offset), indexmap(ax))) +reindex(ax::Axis, offset) = Axis(map(x -> reindex(x, offset), indexmap(ax))) reindex(ax::ViewAxis, offset) = ViewAxis(viewindex(ax) .+ offset, indexmap(ax)) -function reindex(ax::ViewAxis{OldInds, IdxMap, Ax}, - offset) where {OldInds, IdxMap, Ax <: Union{Shaped1DAxis, ShapedAxis}} +function reindex( + ax::ViewAxis{OldInds, IdxMap, Ax}, + offset + ) where {OldInds, IdxMap, Ax <: Union{Shaped1DAxis, ShapedAxis}} NewInds = viewindex(ax) .+ offset return ViewAxis(NewInds, Ax()) end @@ -181,14 +192,19 @@ end @inline Base.getindex(ax::AbstractAxis, ::Colon) = ComponentIndex(:, ax) @inline Base.getindex(::AbstractAxis{IdxMap}, s::Symbol) where {IdxMap} = ComponentIndex(getproperty(IdxMap, s)) @inline Base.getindex( - ::AbstractAxis{IdxMap}, ::Val{s}) where { - IdxMap, s} = ComponentIndex(getproperty(IdxMap, s)) -function Base.getindex(ax::AbstractAxis, syms::Union{ - NTuple{N, Symbol}, <:AbstractArray{Symbol}}) where {N} + ::AbstractAxis{IdxMap}, ::Val{s} +) where { + IdxMap, s, +} = ComponentIndex(getproperty(IdxMap, s)) +function Base.getindex( + ax::AbstractAxis, syms::Union{ + NTuple{N, Symbol}, <:AbstractArray{Symbol}, + } + ) where {N} @assert allunique(syms) "Indexing symbols must all be unique. Got $syms" c_inds = getindex.((ax,), syms) - inds = map(x->x.idx, c_inds) - axs = map(x->x.ax, c_inds) + inds = map(x -> x.idx, c_inds) + axs = map(x -> x.ax, c_inds) last_index = 0 new_axs = map(inds, axs) do i, ax first_index = last_index + 1 @@ -231,5 +247,5 @@ Base.getindex(ax::CombinedAxis, i::AbstractArray) = _array_axis(ax)[i] Base.length(ax::CombinedAxis) = lastindex(ax) - firstindex(ax) + 1 function Base.CartesianIndices(ax::Tuple{CombinedAxis, Vararg{CombinedAxis}}) - CartesianIndices(_array_axis.(ax)) + return CartesianIndices(_array_axis.(ax)) end diff --git a/src/broadcasting.jl b/src/broadcasting.jl index 9de5c740..306f29b1 100644 --- a/src/broadcasting.jl +++ b/src/broadcasting.jl @@ -1,38 +1,47 @@ function Base.BroadcastStyle(::Type{<:ComponentArray{T, N, A, Axes}}) where {T, N, A, Axes} - Broadcast.BroadcastStyle(A) + return Broadcast.BroadcastStyle(A) end # Need special case here for adjoint vectors in order to avoid type instability in axistype function Broadcast.combine_axes(a::ComponentArray, b::AdjOrTransComponentVector) - (axes(a)[1], axes(b)[2]) + return (axes(a)[1], axes(b)[2]) end function Broadcast.combine_axes(a::AdjOrTransComponentVector, b::ComponentArray) - (axes(b)[2], axes(a)[1]) + return (axes(b)[2], axes(a)[1]) end Broadcast.axistype(a::CombinedAxis, b::AbstractUnitRange) = a Broadcast.axistype(a::AbstractUnitRange, b::CombinedAxis) = b function Broadcast.axistype(a::CombinedAxis, b::CombinedAxis) - CombinedAxis(FlatAxis(), Base.Broadcast.axistype(_array_axis(a), _array_axis(b))) + return CombinedAxis(FlatAxis(), Base.Broadcast.axistype(_array_axis(a), _array_axis(b))) end Broadcast.axistype(a::T, b::T) where {T <: CombinedAxis} = a -function Base.promote_shape(a::NTuple{M, CombinedAxis}, b::NTuple{ - N, AbstractUnitRange}) where {M, N} - Base.promote_shape(_array_axis.(a), b) -end -function Base.promote_shape(a::NTuple{N, AbstractUnitRange}, b::NTuple{ - M, CombinedAxis}) where {M, N} - Base.promote_shape(a, _array_axis.(b)) -end -function Base.promote_shape(a::NTuple{M, CombinedAxis}, b::NTuple{ - N, CombinedAxis}) where {M, N} - Base.promote_shape(_array_axis.(a), _array_axis.(b)) -end -Base.promote_shape(a::T, b::T) where {T <: NTuple{N, CombinedAxis} where N} = a +function Base.promote_shape( + a::NTuple{M, CombinedAxis}, b::NTuple{ + N, AbstractUnitRange, + } + ) where {M, N} + return Base.promote_shape(_array_axis.(a), b) +end +function Base.promote_shape( + a::NTuple{N, AbstractUnitRange}, b::NTuple{ + M, CombinedAxis, + } + ) where {M, N} + return Base.promote_shape(a, _array_axis.(b)) +end +function Base.promote_shape( + a::NTuple{M, CombinedAxis}, b::NTuple{ + N, CombinedAxis, + } + ) where {M, N} + return Base.promote_shape(_array_axis.(a), _array_axis.(b)) +end +Base.promote_shape(a::T, b::T) where {T <: NTuple{N, CombinedAxis} where {N}} = a # From https://github.com/JuliaArrays/OffsetArrays.jl/blob/master/src/OffsetArrays.jl Base.dataids(A::ComponentArray) = Base.dataids(parent(A)) function Broadcast.broadcast_unalias(dest::ComponentArray, src) - getdata(dest) === getdata(src) ? src : Broadcast.unalias(dest, src) + return getdata(dest) === getdata(src) ? src : Broadcast.unalias(dest, src) end diff --git a/src/compat/chainrulescore.jl b/src/compat/chainrulescore.jl index 11e9899c..7c3e5f83 100644 --- a/src/compat/chainrulescore.jl +++ b/src/compat/chainrulescore.jl @@ -1,5 +1,8 @@ -function ChainRulesCore.rrule(::typeof(getproperty), x::ComponentArray, s::Union{ - Symbol, Val}) +function ChainRulesCore.rrule( + ::typeof(getproperty), x::ComponentArray, s::Union{ + Symbol, Val, + } + ) return getproperty(x, s), Δ -> getproperty_adjoint(ChainRulesCore.unthunk(Δ), x, s) end @@ -25,36 +28,43 @@ end function ChainRulesCore.rrule( cfg::ChainRulesCore.RuleConfig{>:ChainRulesCore.HasReverseMode}, - ::typeof(__setproperty!), x, s, Δ) + ::typeof(__setproperty!), x, s, Δ + ) y_, pb_f = ChainRulesCore.rrule_via_ad(cfg, __setproperty!, Val(true), x, s, Δ) return y_, pb_f end function ChainRulesCore.rrule(::typeof(getdata), x::ComponentArray) - getdata(x), - Δ -> (ChainRulesCore.NoTangent(), ComponentArray(ChainRulesCore.unthunk(Δ), getaxes(x))) + return getdata(x), + Δ -> (ChainRulesCore.NoTangent(), ComponentArray(ChainRulesCore.unthunk(Δ), getaxes(x))) end function ChainRulesCore.rrule(::Type{ComponentArray}, data, axes) - ComponentArray(data, axes), - Δ -> (ChainRulesCore.NoTangent(), getdata(ChainRulesCore.unthunk(Δ)), - ChainRulesCore.NoTangent()) + return ComponentArray(data, axes), + Δ -> ( + ChainRulesCore.NoTangent(), getdata(ChainRulesCore.unthunk(Δ)), + ChainRulesCore.NoTangent(), + ) end function ChainRulesCore.ProjectTo(ca::ComponentArray) return ChainRulesCore.ProjectTo{ComponentArray}(; - project = ChainRulesCore.ProjectTo(getdata(ca)), axes = getaxes(ca)) + project = ChainRulesCore.ProjectTo(getdata(ca)), axes = getaxes(ca) + ) end function (p::ChainRulesCore.ProjectTo{ComponentArray})(dx::AbstractArray) - ComponentArray(p.project(dx), p.axes) + return ComponentArray(p.project(dx), p.axes) end # Prevent double projection (p::ChainRulesCore.ProjectTo{ComponentArray})(dx::ComponentArray) = dx -function (p::ChainRulesCore.ProjectTo{ComponentArray})(t::ChainRulesCore.Tangent{ - A, <:NamedTuple}) where {A} +function (p::ChainRulesCore.ProjectTo{ComponentArray})( + t::ChainRulesCore.Tangent{ + A, <:NamedTuple, + } + ) where {A} nt = Functors.fmap(ChainRulesCore.backing, ChainRulesCore.backing(t)) return ComponentArray(nt) end @@ -68,8 +78,10 @@ function ChainRulesCore.rrule(::Type{CA}, nt::NamedTuple) where {CA <: Component if length(Δ) == length(y) return ∇NamedTupleToComponentArray(ComponentArray(vec(Δ), getaxes(y))) end - error("Got pullback input of shape $(size(Δ)) & type $(typeof(Δ)) for output " * - "of shape $(size(y)) & type $(typeof(y))") + error( + "Got pullback input of shape $(size(Δ)) & type $(typeof(Δ)) for output " * + "of shape $(size(y)) & type $(typeof(y))" + ) return nothing end diff --git a/src/compat/static_arrays.jl b/src/compat/static_arrays.jl index 26b56cd0..e984361d 100644 --- a/src/compat/static_arrays.jl +++ b/src/compat/static_arrays.jl @@ -1,11 +1,12 @@ function ComponentArray{A}(::UndefInitializer, ax::Axes) where { - A <: StaticArray, Axes <: Tuple} + A <: StaticArray, Axes <: Tuple, + } return ComponentArray(similar(A), ax...) end _maybe_SArray(x::SubArray, ::Val{N}, ::FlatAxis) where {N} = SVector{N}(x) function _maybe_SArray(x::Base.ReshapedArray, ::Val, ::ShapedAxis{Sz}) where {Sz} - SArray{Tuple{Sz...}}(x) + return SArray{Tuple{Sz...}}(x) end _maybe_SArray(x, ::Val, ::Shaped1DAxis{Sz}) where {Sz} = SArray{Tuple{Sz...}}(x) _maybe_SArray(x, vals...) = x diff --git a/src/componentarray.jl b/src/componentarray.jl index 2df25227..dd5a923e 100644 --- a/src/componentarray.jl +++ b/src/componentarray.jl @@ -33,7 +33,8 @@ julia> collect(x) ``` """ struct ComponentArray{ - T, N, A <: AbstractArray{T, N}, Axes <: Tuple{Vararg{AbstractAxis}}} <: DenseArray{T, N} + T, N, A <: AbstractArray{T, N}, Axes <: Tuple{Vararg{AbstractAxis}}, + } <: DenseArray{T, N} data::A axes::Axes end @@ -41,14 +42,15 @@ end # Entry from type (used for broadcasting) ComponentArray{Axes}(data) where {Axes} = ComponentArray(data, getaxes(Axes)...) function ComponentArray(::UndefInitializer, ax::Axes) where {Axes <: Tuple} - ComponentArray(similar(Array{Float64}, last_index.(ax)), ax...) + return ComponentArray(similar(Array{Float64}, last_index.(ax)), ax...) end function ComponentArray{A}(::UndefInitializer, ax::Axes) where { - A <: AbstractArray, Axes <: Tuple} - ComponentArray(similar(A, last_index.(ax)), ax...) + A <: AbstractArray, Axes <: Tuple, + } + return ComponentArray(similar(A, last_index.(ax)), ax...) end function ComponentArray{T}(::UndefInitializer, ax::Axes) where {T, Axes <: Tuple} - ComponentArray(similar(Array{T}, last_index.(ax)), ax...) + return ComponentArray(similar(Array{T}, last_index.(ax)), ax...) end # Entry from data array and AbstractAxis types dispatches to correct shapes and partitions @@ -57,7 +59,7 @@ end ComponentArray(data, ::Union{FlatAxis, Shaped1DAxis}...) = data ComponentArray(data, ax::NotShapedOrPartitionedAxis...) = ComponentArray(data, ax) function ComponentArray(data, ax::NotPartitionedAxis...) - ComponentArray(maybe_reshape(data, ax...), unshape.(ax)...) + return ComponentArray(maybe_reshape(data, ax...), unshape.(ax)...) end function ComponentArray(data, ax::AbstractAxis...) part_axs = filter_by_type(PartitionedAxis, ax...) @@ -79,9 +81,11 @@ function Adapt.adapt_structure(to, x::ComponentArray) return ComponentArray(data, getaxes(x)) end -function Adapt.adapt_storage(::Type{ComponentArray{T, N, A, Ax}}, - xs::AT) where {T, N, A, Ax, AT <: AbstractArray} - Adapt.adapt_storage(A, xs) +function Adapt.adapt_storage( + ::Type{ComponentArray{T, N, A, Ax}}, + xs::AT + ) where {T, N, A, Ax, AT <: AbstractArray} + return Adapt.adapt_storage(A, xs) end Adapt.parent_type(::Type{ComponentArray{T, N, A, Ax}}) where {T, N, A, Ax} = A @@ -90,7 +94,7 @@ Adapt.parent_type(::Type{ComponentArray{T, N, A, Ax}}) where {T, N, A, Ax} = A ComponentArray{T}(nt::NamedTuple) where {T} = ComponentArray(make_carray_args(T, nt)...) ComponentArray{T}(::NamedTuple{(), Tuple{}}) where {T} = ComponentArray(T[], (FlatAxis(),)) function ComponentArray(nt::Union{NamedTuple, AbstractDict}) - ComponentArray(make_carray_args(nt)...) + return ComponentArray(make_carray_args(nt)...) end ComponentArray(::NamedTuple{(), Tuple{}}) = ComponentArray(Any[], (FlatAxis(),)) ComponentArray{T}(; kwargs...) where {T} = ComponentArray{T}((; kwargs...)) @@ -99,7 +103,7 @@ ComponentArray(; kwargs...) = ComponentArray((; kwargs...)) ComponentArray(x::ComponentArray) = x ComponentArray{T}(x::ComponentArray) where {T} = T.(x) function (CA::Type{<:ComponentArray{T, N, A, Ax}})(x::ComponentArray) where {T, N, A, Ax} - ComponentArray(T.(getdata(x)), getaxes(x)) + return ComponentArray(T.(getdata(x)), getaxes(x)) end function fill_componentarray_ka! end # defined in extensions @@ -125,7 +129,7 @@ function ComponentVector(data::AbstractArray, ax) end function ConstructionBase.setproperties(x::ComponentVector, patch::NamedTuple) - ComponentVector(x; patch...) + return ComponentVector(x; patch...) end # Add new fields to component Vector @@ -154,7 +158,7 @@ ComponentMatrix{T}(x::ComponentMatrix) where {T} = T.(x) ComponentMatrix() = ComponentMatrix(Array{Any}(undef, 0, 0), (FlatAxis(), FlatAxis())) function ComponentMatrix{T}() where {T} - ComponentMatrix(Array{T}(undef, 0, 0), (FlatAxis(), FlatAxis())) + return ComponentMatrix(Array{T}(undef, 0, 0), (FlatAxis(), FlatAxis())) end const CArray = ComponentArray @@ -163,22 +167,29 @@ const CMatrix = ComponentMatrix const AdjOrTrans{T, A} = Union{Adjoint{T, A}, Transpose{T, A}} const AdjOrTransComponentArray{ - T, A} = Union{Adjoint{T, A}, Transpose{T, A}} where {A <: ComponentArray} + T, A, +} = Union{Adjoint{T, A}, Transpose{T, A}} where {A <: ComponentArray} const AdjOrTransComponentVector{T} = Union{ - Adjoint{T, A}, Transpose{T, A}} where {A <: ComponentVector} + Adjoint{T, A}, Transpose{T, A}, +} where {A <: ComponentVector} const AdjOrTransComponentMatrix{T} = Union{ - Adjoint{T, A}, Transpose{T, A}} where {A <: ComponentMatrix} + Adjoint{T, A}, Transpose{T, A}, +} where {A <: ComponentMatrix} const ComponentVecOrMat{T} = Union{ComponentVector{T}, ComponentMatrix{T}} where {T} const AdjOrTransComponentVecOrMat{T} = AdjOrTrans{T, <:ComponentVecOrMat} where {T} const AbstractComponentArray{T} = Union{ - ComponentArray{T}, AdjOrTransComponentArray{T}} where {T} + ComponentArray{T}, AdjOrTransComponentArray{T}, +} where {T} const AbstractComponentVecOrMat{T} = Union{ - ComponentVecOrMat{T}, AdjOrTransComponentVecOrMat{T}} where {T} + ComponentVecOrMat{T}, AdjOrTransComponentVecOrMat{T}, +} where {T} const AbstractComponentVector{T} = Union{ - ComponentVector{T}, AdjOrTransComponentVector{T}} where {T} + ComponentVector{T}, AdjOrTransComponentVector{T}, +} where {T} const AbstractComponentMatrix{T} = Union{ - ComponentMatrix{T}, AdjOrTransComponentMatrix{T}} where {T} + ComponentMatrix{T}, AdjOrTransComponentMatrix{T}, +} where {T} ## Constructor helpers allocate_numeric_container(x) = allocate_numeric_container(recursive_eltype(x)) @@ -190,7 +201,7 @@ make_carray_args(::NamedTuple{(), Tuple{}}) = (Any[], FlatAxis()) make_carray_args(::Type{T}, ::NamedTuple{(), Tuple{}}) where {T} = (T[], FlatAxis()) function make_carray_args(nt) data, ax = make_carray_args(Vector, nt) - data = length(data)==1 ? [data[1]] : map(identity, data) + data = length(data) == 1 ? [data[1]] : map(identity, data) return (data, ax) end make_carray_args(::Type{T}, nt) where {T} = make_carray_args(Vector{T}, nt) @@ -204,14 +215,16 @@ end function make_idx(data, nt::Union{NamedTuple, AbstractDict}, last_val) len = recursive_length(nt) lv = Ref(0) # workaround for https://github.com/JuliaLang/julia/issues/15276 - kvs = (; ( - k => begin - inds = make_idx(data, v, lv[])[2] - lv[] = last_index(inds) - inds - end - for (k, v) in pairs(nt) - )...) + kvs = (; + ( + k => begin + inds = make_idx(data, v, lv[])[2] + lv[] = last_index(inds) + inds + end + for (k, v) in pairs(nt) + )..., + ) return (data, ViewAxis(last_index(last_val) .+ (1:len), kvs)) end function make_idx(data, nt::NamedTuple{(), Tuple{}}, last_val) @@ -225,15 +238,15 @@ function make_idx(data, pair::Pair, last_val) end make_idx(data, x, last_val) = ( push!(data, x), - ViewAxis(last_index(last_val) + 1) + ViewAxis(last_index(last_val) + 1), ) function make_idx(data, x::ComponentVector, last_val) - ( + return ( append!(data, x), ViewAxis( last_index(last_val) .+ (1:length(x)), getaxes(x)[1] - ) + ), ) end function make_idx(data, x::AbstractArray, last_val) @@ -241,8 +254,13 @@ function make_idx(data, x::AbstractArray, last_val) out = last_index(last_val) .+ (1:length(x)) return (data, ViewAxis(out, ShapedAxis(size(x)))) end -function make_idx(data, x::A, last_val) where {A <: AbstractArray{<:Union{ - NamedTuple, AbstractArray}}} +function make_idx(data, x::A, last_val) where { + A <: AbstractArray{ + <:Union{ + NamedTuple, AbstractArray, + }, + }, + } len = recursive_length(x) elem_len = len ÷ length(x) if eltype(x) |> isconcretetype && all(elem -> recursive_length(elem) == elem_len, x) @@ -258,7 +276,7 @@ function make_idx(data, x::A, last_val) where {A <: AbstractArray{<:Union{ elem_len, indexmap(out) ) - ) + ), ) else error("Only homogeneous arrays are allowed.") @@ -270,7 +288,7 @@ end #TODO: Make all internal function names start with underscores function _maybe_add_field(x, pair) - haskey(x, pair.first) ? _update_field(x, pair) : _add_field(x, pair) + return haskey(x, pair.first) ? _update_field(x, pair) : _add_field(x, pair) end function _add_field(x, pair) data = copy(getdata(x)) @@ -294,7 +312,7 @@ function maybe_reshape(data, axs::AbstractAxis...) end function Base.reshape(A::AbstractArray, axs::Tuple{CombinedAxis, Vararg{CombinedAxis}}) - reshape(A, _array_axis.(axs)) + return reshape(A, _array_axis.(axs)) end # Recurse through nested ViewAxis types to find the last index @@ -353,20 +371,38 @@ julia> getaxes(ca) ``` """ @inline getaxes(x::ComponentArray) = getfield(x, :axes) -@inline getaxes(x::AdjOrTrans{ - T, <:ComponentVector}) where {T} = (FlatAxis(), getaxes(x.parent)[1]) +@inline getaxes( + x::AdjOrTrans{ + T, <:ComponentVector, + } +) where {T} = (FlatAxis(), getaxes(x.parent)[1]) @inline getaxes(x::AdjOrTrans{T, <:ComponentMatrix}) where {T} = reverse(getaxes(x.parent)) -@inline getaxes(::Type{<:ComponentArray{ - T, N, A, Axes}}) where {T, N, A, Axes} = map(x->x(), (Axes.types...,)) -@inline getaxes(::Type{<:AdjOrTrans{ - T, CA}}) where {T, CA <: ComponentVector} = (FlatAxis(), getaxes(CA)[1]) |> typeof -@inline getaxes(::Type{<:AdjOrTrans{ - T, CA}}) where {T, CA <: ComponentMatrix} = reverse(getaxes(CA)) |> typeof +@inline getaxes( + ::Type{ + <:ComponentArray{ + T, N, A, Axes, + }, + } +) where {T, N, A, Axes} = map(x -> x(), (Axes.types...,)) +@inline getaxes( + ::Type{ + <:AdjOrTrans{ + T, CA, + }, + } +) where {T, CA <: ComponentVector} = (FlatAxis(), getaxes(CA)[1]) |> typeof +@inline getaxes( + ::Type{ + <:AdjOrTrans{ + T, CA, + }, + } +) where {T, CA <: ComponentMatrix} = reverse(getaxes(CA)) |> typeof ## Field access through these functions to reserve dot-getting for keys @inline getaxes(x::VarAxes) = getaxes(typeof(x)) -@inline getaxes(Ax::Type{Axes}) where {Axes <: VarAxes} = map(x->x(), (Ax.types...,)) +@inline getaxes(Ax::Type{Axes}) where {Axes <: VarAxes} = map(x -> x(), (Ax.types...,)) getaxes(x) = () diff --git a/src/componentindex.jl b/src/componentindex.jl index e69acd7a..2697d265 100644 --- a/src/componentindex.jl +++ b/src/componentindex.jl @@ -7,14 +7,14 @@ ComponentIndex(idx::CartesianIndex) = ComponentIndex(idx, ShapedAxis((1,))) ComponentIndex(idx::AbstractArray{<:Integer}) = ComponentIndex(idx, ShapedAxis(size(idx))) ComponentIndex(idx::Int) = ComponentIndex(idx, NullAxis()) function ComponentIndex(vax::ViewAxis{Inds, IdxMap, Ax}) where {Inds, IdxMap, Ax} - ComponentIndex(Inds, vax.ax) + return ComponentIndex(Inds, vax.ax) end const FlatComponentIndex{Idx} = ComponentIndex{Idx, FlatAxis} const NullComponentIndex{Idx} = ComponentIndex{Idx, NullAxis} function Base.:(==)(ci1::ComponentIndex, ci2::ComponentIndex) - ci1.idx == ci2.idx && ci1.ax == ci2.ax + return ci1.idx == ci2.idx && ci1.ax == ci2.ax end Base.length(ci::ComponentIndex) = length(ci.idx) @@ -33,10 +33,12 @@ Base.getindex(ax::AbstractAxis, i::KeepIndex{Idx}) where {Idx} = _getindex_keep( _getindex_keep(ax::AbstractAxis, ::Colon) = ComponentIndex(:, ax) function _getindex_keep(ax::AbstractAxis, idx::AbstractRange) idx_map = indexmap(ax) - keeps = (s=>x for (s, x) in pairs(idx_map) if first(viewindex(x)) in idx && - last(viewindex(x)) in idx) + keeps = ( + s => x for (s, x) in pairs(idx_map) if first(viewindex(x)) in idx && + last(viewindex(x)) in idx + ) keeps = NamedTuple{Tuple(first.(keeps))}(Tuple(last.(keeps))) - new_ax = reindex(Axis(keeps), -first(idx)+1) + new_ax = reindex(Axis(keeps), -first(idx) + 1) return ComponentIndex(idx, new_ax) end function _getindex_keep(ax::AbstractAxis, sym::Symbol) @@ -50,6 +52,6 @@ function _getindex_keep(ax::AbstractAxis, sym::Symbol) else new_ax = Axis(NamedTuple{(sym,)}((ViewAxis(idx, ci.ax),))) end - new_ax = reindex(new_ax, -first(idx)+1) + new_ax = reindex(new_ax, -first(idx) + 1) return ComponentIndex(idx, new_ax) end diff --git a/src/lazyarray.jl b/src/lazyarray.jl index 5df7aae8..6cc1144f 100644 --- a/src/lazyarray.jl +++ b/src/lazyarray.jl @@ -8,7 +8,7 @@ struct LazyArray{T, N, G} <: AbstractArray{T, N} gen::G LazyArray{T}(gen) where {T} = new{T, ndims(gen), typeof(gen)}(gen) function LazyArray(gen::Base.Generator{A, F}) where {A, F} - new{eltype(A), ndims(gen), typeof(gen)}(gen) + return new{eltype(A), ndims(gen), typeof(gen)}(gen) end end @@ -18,7 +18,7 @@ const LazyMatrix{T, G} = LazyArray{T, 2, G} Base.getindex(a::LazyArray, i...) = _un_iter(getfield(a, :gen), i) function Base.setindex!(a::LazyArray, val, i...) - a[i...] .= val + return a[i...] .= val end _un_iter(iter, idxs) = _un_iter(iter.f, iter.iter, idxs) diff --git a/src/linear_algebra.jl b/src/linear_algebra.jl index aee46887..b5302335 100644 --- a/src/linear_algebra.jl +++ b/src/linear_algebra.jl @@ -38,7 +38,8 @@ for op in [:*, :\, :/] return $adj(ComponentArray(cᵀ', ax2)) end function Base.$op(A::$Adj{T, <:CV}, B::CV) where { - T <: Real, CV <: ComponentVector{T}} + T <: Real, CV <: ComponentVector{T}, + } return $op(getdata(A), getdata(B)) end end @@ -60,5 +61,5 @@ function LinearAlgebra.axpby!(α::Number, x::ComponentArray, β::Number, y::Comp end function LinearAlgebra.ldiv!(B::AbstractVecOrMat, D::Diagonal{Float64, <:ComponentArray}, A::AbstractVecOrMat) - ldiv!(B, Diagonal(Vector(D.diag)), A) + return ldiv!(B, Diagonal(Vector(D.diag)), A) end diff --git a/src/namedtuple_interface.jl b/src/namedtuple_interface.jl index 4967faa1..eba7bb57 100644 --- a/src/namedtuple_interface.jl +++ b/src/namedtuple_interface.jl @@ -1,10 +1,10 @@ Base.hash(x::ComponentArray, h::UInt) = hash(keys(x), hash(getdata(x), h)) function Base.:(==)(x::ComponentArray, y::ComponentArray) - getdata(x)==getdata(y) && getaxes(x)==getaxes(y) + return getdata(x) == getdata(y) && getaxes(x) == getaxes(y) end -Base.:(==)(x::ComponentArray, y::AbstractArray) = getdata(x)==y && keys(x)==keys(y) # For equality with LabelledArrays -Base.:(==)(x::AbstractArray, y::ComponentArray) = y==x +Base.:(==)(x::ComponentArray, y::AbstractArray) = getdata(x) == y && keys(x) == keys(y) # For equality with LabelledArrays +Base.:(==)(x::AbstractArray, y::ComponentArray) = y == x Base.keys(x::ComponentVector) = keys(indexmap(getaxes(x)[1])) diff --git a/src/plot_utils.jl b/src/plot_utils.jl index cfca240a..9d95f05f 100644 --- a/src/plot_utils.jl +++ b/src/plot_utils.jl @@ -1,4 +1,3 @@ - """ labels(x::ComponentVector) @@ -31,16 +30,16 @@ julia> ComponentArrays.labels(x) see also [`label2index`](@ref) """ -labels(x::ComponentVector) = map(x->x[(firstindex(x) + 1):end], _labels(x)) -labels(x) = map(x->x[firstindex(x):end], _labels(x)) +labels(x::ComponentVector) = map(x -> x[(firstindex(x) + 1):end], _labels(x)) +labels(x) = map(x -> x[firstindex(x):end], _labels(x)) _labels(x::ComponentVector) = vcat((".$(key)" .* _labels(x[key]) for key in keys(x))...) function _labels(x::AbstractArray{<:ComponentArray}) - vcat(("[$i]" .* _labels(x[i]) for i in eachindex(x))...) + return vcat(("[$i]" .* _labels(x[i]) for i in eachindex(x))...) end _labels(x::LazyArray) = vcat(("[$i]" .* _labels(x[i]) for i in eachindex(x))...) function _labels(x::AbstractArray) - vcat(("[" * join(i.I, ",") * "]" for i in CartesianIndices(x))...) + return vcat(("[" * join(i.I, ",") * "]" for i in CartesianIndices(x))...) end _labels(x) = "" diff --git a/src/show.jl b/src/show.jl index a650573e..f41988b8 100644 --- a/src/show.jl +++ b/src/show.jl @@ -1,6 +1,6 @@ # Show AbstractAxis types function Base.show(io::IO, ::MIME"text/plain", ::Axis{IdxMap}) where {IdxMap} - print(io, "Axis$IdxMap") + return print(io, "Axis$IdxMap") end Base.show(io::IO, ::Axis{IdxMap}) where {IdxMap} = print(io, "Axis$IdxMap") @@ -9,60 +9,68 @@ Base.show(io::IO, ::MIME"text/plain", ::FlatAxis) = print(io, "FlatAxis()") Base.show(io::IO, ::NullAxis) = print(io, "NullAxis()") -function Base.show(io::IO, ::MIME"text/plain", ::PartitionedAxis{ - PartSz, IdxMap, Ax}) where {PartSz, IdxMap, Ax} - print(io, "PartitionedAxis($PartSz, $(Ax()))") +function Base.show( + io::IO, ::MIME"text/plain", ::PartitionedAxis{ + PartSz, IdxMap, Ax, + } + ) where {PartSz, IdxMap, Ax} + return print(io, "PartitionedAxis($PartSz, $(Ax()))") end function Base.show(io::IO, ::PartitionedAxis{PartSz, IdxMap, Ax}) where {PartSz, IdxMap, Ax} - print(io, "PartitionedAxis($PartSz, $(Ax()))") + return print(io, "PartitionedAxis($PartSz, $(Ax()))") end Base.show(io::IO, ::ShapedAxis{Shape}) where {Shape} = print(io, "ShapedAxis($Shape)") Base.show(io::IO, ::Shaped1DAxis{Shape}) where {Shape} = print(io, "Shaped1DAxis($Shape)") -function Base.show(io::IO, ::MIME"text/plain", ::ViewAxis{ - Inds, IdxMap, Ax}) where {Inds, IdxMap, Ax} - print(io, "ViewAxis($Inds, $(Ax()))") +function Base.show( + io::IO, ::MIME"text/plain", ::ViewAxis{ + Inds, IdxMap, Ax, + } + ) where {Inds, IdxMap, Ax} + return print(io, "ViewAxis($Inds, $(Ax()))") end function Base.show(io::IO, ::ViewAxis{Inds, IdxMap, <:Ax}) where {Inds, IdxMap, Ax} - print(io, "ViewAxis($Inds, $(Ax()))") + return print(io, "ViewAxis($Inds, $(Ax()))") end function Base.show(io::IO, ::ViewAxis{Inds, IdxMap, <:NullorFlatAxis}) where {Inds, IdxMap} - print(io, Inds) + return print(io, Inds) end Base.show(io::IO, ci::ComponentIndex) = print(io, "ComponentIndex($(ci.idx), $(ci.ax))") # Show ComponentArrays function _print_type_short(io, ca; color = :normal) - _print_type_short(io, typeof(ca); color = color) + return _print_type_short(io, typeof(ca); color = color) end _print_type_short(io, T::Type; color = :normal) = printstyled(io, T; color = color) function _print_type_short(io, ::Type{<:ComponentArray{T, N, <:Array}}; color = :normal) where { - T, N} - printstyled(io, "ComponentArray{$T,$N}"; color = color) + T, N, + } + return printstyled(io, "ComponentArray{$T,$N}"; color = color) end # do not pollute the stacktrace with verbose type printing function _print_type_short(io, ::Type{<:ComponentArray{T, 1, <:Array}}; color = :normal) where {T} - printstyled(io, "ComponentVector{$T}"; color = color) + return printstyled(io, "ComponentVector{$T}"; color = color) end function _print_type_short(io, ::Type{<:ComponentArray{T, 2, <:Array}}; color = :normal) where {T} - printstyled(io, "ComponentMatrix{$T}"; color = color) + return printstyled(io, "ComponentMatrix{$T}"; color = color) end function _print_type_short(io, ::Type{<:ComponentArray{T, N, <:SubArray}}; color = :normal) where { - T, N} - printstyled(io, "ComponentArray{$T,$N,SubArray...}"; color = color) + T, N, + } + return printstyled(io, "ComponentArray{$T,$N,SubArray...}"; color = color) end # do not pollute the stacktrace with verbose type printing function _print_type_short(io, ::Type{<:ComponentArray{T, 1, <:SubArray}}; color = :normal) where {T} - printstyled(io, "ComponentVector{$T,SubArray...}"; color = color) + return printstyled(io, "ComponentVector{$T,SubArray...}"; color = color) end function _print_type_short(io, ::Type{<:ComponentArray{T, 2, <:SubArray}}; color = :normal) where {T} - printstyled(io, "ComponentMatrix{$T,SubArray...}"; color = color) + return printstyled(io, "ComponentMatrix{$T,SubArray...}"; color = color) end function Base.show(io::IO, x::ComponentVector) print(io, "(") for (i, key) in enumerate(keys(x)) - if i==1 + if i == 1 print(io, "$key = ") else print(io, ", $key = ") @@ -88,8 +96,11 @@ function Base.show(io::IO, mime::MIME"text/plain", x::ComponentVector) return nothing end -function Base.show(io::IO, ::MIME"text/plain", x::ComponentMatrix{ - T, A, Axes}) where {T, A, Axes} +function Base.show( + io::IO, ::MIME"text/plain", x::ComponentMatrix{ + T, A, Axes, + } + ) where {T, A, Axes} if !haskey(io, :compact) && length(axes(x, 2)) > 1 io = IOContext(io, :compact => true) end diff --git a/src/similar_convert_copy.jl b/src/similar_convert_copy.jl index 85d19dcd..1a9b5f77 100644 --- a/src/similar_convert_copy.jl +++ b/src/similar_convert_copy.jl @@ -1,42 +1,44 @@ const CombinedAnyDims = Tuple{<:CombinedAxis, Vararg{CombinedOrRegularAxis}} const AnyCombinedAnyDims = Tuple{ - <:CombinedOrRegularAxis, <:CombinedAxis, Vararg{CombinedOrRegularAxis}} + <:CombinedOrRegularAxis, <:CombinedAxis, Vararg{CombinedOrRegularAxis}, +} const CombinedCombinedAnyDims = Tuple{ - <:CombinedAxis, <:CombinedAxis, Vararg{CombinedOrRegularAxis}} + <:CombinedAxis, <:CombinedAxis, Vararg{CombinedOrRegularAxis}, +} # Similar Base.similar(x::ComponentArray) = ComponentArray(similar(getdata(x)), getaxes(x)...) function Base.similar(x::ComponentArray, ::Type{T}) where {T} - ComponentArray(similar(getdata(x), T), getaxes(x)...) + return ComponentArray(similar(getdata(x), T), getaxes(x)...) end Base.similar(x::ComponentArray, dims::Vararg{Int}) = similar(getdata(x), dims...) function Base.similar(x::ComponentArray, ::Type{T}, dims::Vararg{Int}) where {T} - similar(getdata(x), T, dims...) + return similar(getdata(x), T, dims...) end Base.similar(x::AbstractArray, dims::CombinedAnyDims) = _similar(x, dims) Base.similar(x::AbstractArray, dims::AnyCombinedAnyDims) = _similar(x, dims) Base.similar(x::AbstractArray, dims::CombinedCombinedAnyDims) = _similar(x, dims) function Base.similar(x::AbstractArray, ::Type{T}, dims::CombinedAnyDims) where {T} - _similar(x, T, dims) + return _similar(x, T, dims) end function Base.similar(x::AbstractArray, ::Type{T}, dims::AnyCombinedAnyDims) where {T} - _similar(x, T, dims) + return _similar(x, T, dims) end function Base.similar(x::AbstractArray, ::Type{T}, dims::CombinedCombinedAnyDims) where {T} - _similar(x, T, dims) + return _similar(x, T, dims) end Base.similar(x::Type{<:AbstractArray}, dims::CombinedAnyDims) = _similar(x, dims) Base.similar(x::Type{<:AbstractArray}, dims::AnyCombinedAnyDims) = _similar(x, dims) Base.similar(x::Type{<:AbstractArray}, dims::CombinedCombinedAnyDims) = _similar(x, dims) function _similar(x::AbstractArray, dims) - ComponentArray(similar(getdata(x), length.(_array_axis.(dims))), _component_axis.(dims)...) + return ComponentArray(similar(getdata(x), length.(_array_axis.(dims))), _component_axis.(dims)...) end function _similar(x::Type, dims) - ComponentArray(similar(x, length.(_array_axis.(dims))), _component_axis.(dims)...) + return ComponentArray(similar(x, length.(_array_axis.(dims))), _component_axis.(dims)...) end function _similar(x, T, dims) - ComponentArray(similar(getdata(x), T, length.(_array_axis.(dims))), _component_axis.(dims)...) + return ComponentArray(similar(getdata(x), T, length.(_array_axis.(dims))), _component_axis.(dims)...) end Base.zero(x::ComponentArray) = ComponentArray(zero(getdata(x)), getaxes(x)...) @@ -62,38 +64,53 @@ end Base.deepcopy(x::ComponentArray) = ComponentArray(deepcopy(getdata(x)), getaxes(x)) function Base.convert(::Type{ComponentArray{T, N, AA, Ax}}, A::AbstractArray) where { - T, N, AA, Ax} + T, N, AA, Ax, + } return ComponentArray{Ax}(A) end -function Base.convert(::Type{ComponentArray{T, N, A, Ax1}}, - x::ComponentArray{T, N, A, Ax2}) where {T, N, A, Ax1, Ax2} +function Base.convert( + ::Type{ComponentArray{T, N, A, Ax1}}, + x::ComponentArray{T, N, A, Ax2} + ) where {T, N, A, Ax1, Ax2} return x end -function Base.convert(::Type{ComponentArray{T1, N, A1, Ax1}}, - x::ComponentArray{T2, N, A2, Ax2}) where {T1, T2, N, A1, A2, Ax1, Ax2} +function Base.convert( + ::Type{ComponentArray{T1, N, A1, Ax1}}, + x::ComponentArray{T2, N, A2, Ax2} + ) where {T1, T2, N, A1, A2, Ax1, Ax2} return T1.(x) end -function Base.convert(::Type{ComponentArray{T, N, A1, Ax}}, - x::ComponentArray{T, N, A2, Ax}) where {T, N, A1, A2, Ax} +function Base.convert( + ::Type{ComponentArray{T, N, A1, Ax}}, + x::ComponentArray{T, N, A2, Ax} + ) where {T, N, A1, A2, Ax} return x end -function Base.convert(::Type{ComponentArray{T, N, A, Ax}}, x::ComponentArray{ - T, N, A, Ax}) where {T, N, A, Ax} +function Base.convert( + ::Type{ComponentArray{T, N, A, Ax}}, x::ComponentArray{ + T, N, A, Ax, + } + ) where {T, N, A, Ax} return x end Base.convert(T::Type{<:Array}, x::ComponentArray) = convert(T, getdata(x)) -function Base.convert(::Type{Cholesky{T1, Matrix{T1}}}, x::Cholesky{ - T2, <:ComponentArray}) where {T1, T2} - Cholesky(Matrix{T1}(x.factors), x.uplo, x.info) +function Base.convert( + ::Type{Cholesky{T1, Matrix{T1}}}, x::Cholesky{ + T2, <:ComponentArray, + } + ) where {T1, T2} + return Cholesky(Matrix{T1}(x.factors), x.uplo, x.info) end # Conversion to from ComponentArray to NamedTuple (note, does not preserve numeric types of # original NamedTuple) function _namedtuple(x::ComponentVector) - NamedTuple{keys(x)}(map(valkeys(x)) do key - _namedtuple(getproperty(x, key)) - end) + return NamedTuple{keys(x)}( + map(valkeys(x)) do key + _namedtuple(getproperty(x, key)) + end + ) end _namedtuple(v::AbstractVector) = _namedtuple.(v) _namedtuple(x) = x diff --git a/src/utils.jl b/src/utils.jl index 5b498e74..579dd35b 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -20,16 +20,23 @@ function partition(m, N1, N2) ax = axes(m) firsts = firstindex.(ax) lasts = lastindex.(ax) - return (view(m, i:(i + N1 - 1), j:(j + N2 - 1)) for i in firsts[1]:N1:lasts[1], - j in firsts[2]:N2:lasts[2]) + return ( + view(m, i:(i + N1 - 1), j:(j + N2 - 1)) for i in firsts[1]:N1:lasts[1], + j in firsts[2]:N2:lasts[2] + ) end # Slower fallback for higher dimensions function partition(a::A, N::Tuple) where {A <: AbstractArray} ax = axes(a) offs = firstindex.(ax) - return (view(a, (:).((I.I .- 1) .* N .+ offs, ((I.I .- 1) .* N .+ N .- 1 .+ offs))...) for I in - CartesianIndices(div.( - size(a), N))) + return ( + view(a, (:).((I.I .- 1) .* N .+ offs, ((I.I .- 1) .* N .+ N .- 1 .+ offs))...) for I in + CartesianIndices( + div.( + size(a), N + ) + ) + ) end # partition(a::A, N::Tuple) where A<:AbstractVector = reshape(view(a, :), N) @@ -37,10 +44,10 @@ end filter_by_type(::Type{T}, args...) where {T} = filter_by_type(T, (), args...) filter_by_type(::Type{T}, part::Tuple) where {T} = part function filter_by_type(::Type{T}, part::Tuple, ax, args...) where {T} - filter_by_type(T, part, args...) + return filter_by_type(T, part, args...) end function filter_by_type(::Type{T}, part::Tuple, ax::T, args...) where {T} - filter_by_type(T, (part..., ax), args...) + return filter_by_type(T, (part..., ax), args...) end # Flat length of an arbitrarily nested named tuple @@ -53,13 +60,13 @@ recursive_length(nt::NamedTuple{(), Tuple{}}) = 0 # Find the highest element type function recursive_eltype(nt::NamedTuple) - isempty(nt) ? Base.Bottom : mapreduce(recursive_eltype, promote_type, nt) + return isempty(nt) ? Base.Bottom : mapreduce(recursive_eltype, promote_type, nt) end function recursive_eltype(x::AbstractArray{<:Any}) - isempty(x) ? Base.Bottom : mapreduce(recursive_eltype, promote_type, x) + return isempty(x) ? Base.Bottom : mapreduce(recursive_eltype, promote_type, x) end function recursive_eltype(x::Dict) - isempty(x) ? Base.Bottom : mapreduce(recursive_eltype, promote_type, values(x)) + return isempty(x) ? Base.Bottom : mapreduce(recursive_eltype, promote_type, values(x)) end recursive_eltype(::AbstractArray{T, N}) where {T <: Number, N} = T recursive_eltype(x) = typeof(x) diff --git a/test/autodiff_tests.jl b/test/autodiff_tests.jl index 339a3d0b..058f259e 100644 --- a/test/autodiff_tests.jl +++ b/test/autodiff_tests.jl @@ -28,8 +28,8 @@ truth = ComponentArray(a = [32, 48], x = 156) zygote_full = Zygote.gradient(F_, ca)[1] @test zygote_full ≈ truth - @test ComponentArray(x = 4.0,) ≈ Zygote.gradient(ComponentArray(x = 2,)) do c - (; c...,).x^2 + @test ComponentArray(x = 4.0) ≈ Zygote.gradient(ComponentArray(x = 2)) do c + (; c...).x^2 end[1] # Issue #148 @@ -131,5 +131,5 @@ end ps = ComponentArray(; a = rand(2), b = (; c = rand(2))) ps_tracked = Tracker.param(ps) @test ArrayInterface.restructure(ps, ps_tracked) isa - ComponentVector{<:Any, <:Tracker.TrackedArray} + ComponentVector{<:Any, <:Tracker.TrackedArray} end diff --git a/test/diffeq_test/diffeq_tests.jl b/test/diffeq_test/diffeq_tests.jl index 4e6d5644..fbaa4453 100644 --- a/test/diffeq_test/diffeq_tests.jl +++ b/test/diffeq_test/diffeq_tests.jl @@ -10,21 +10,21 @@ using Unitful y₁, y₂, y₃ = vars k₁, k₂, k₃ = p D = similar(vars) - D.y₁ = -k₁*y₁+k₃*y₂*y₃ - D.y₂ = k₁*y₁-k₂*y₂^2-k₃*y₂*y₃ - D.y₃ = k₂*y₂^2 + D.y₁ = -k₁ * y₁ + k₃ * y₂ * y₃ + D.y₂ = k₁ * y₁ - k₂ * y₂^2 - k₃ * y₂ * y₃ + D.y₃ = k₂ * y₂^2 return D end ic = ComponentArray(y₁ = 1.0, y₂ = 0.0, y₃ = 0.0) - prob = ODEProblem(rober, ic, (0.0, 1e11), (0.04, 3e7, 1e4)) + prob = ODEProblem(rober, ic, (0.0, 1.0e11), (0.04, 3.0e7, 1.0e4)) sol = solve(prob, Rosenbrock23()) @test sol[1] isa ComponentArray end @testset "Issue 53" begin x0 = ComponentArray(x = ones(10)) - prob = ODEProblem((u, p, t)->u, x0, (0.0, 1.0)) - sol = solve(prob, CVODE_BDF(linear_solver = :BCG), reltol = 1e-15, abstol = 1e-15) + prob = ODEProblem((u, p, t) -> u, x0, (0.0, 1.0)) + sol = solve(prob, CVODE_BDF(linear_solver = :BCG), reltol = 1.0e-15, abstol = 1.0e-15) @test sol(1)[1] ≈ exp(1) end @@ -65,34 +65,34 @@ end p = [0.1, 0.1] - lu_0 = @LArray fill(1000.0, 2*n) (x = (1:n), y = ((n + 1):(2 * n))) + lu_0 = @LArray fill(1000.0, 2 * n) (x = (1:n), y = ((n + 1):(2 * n))) cu_0 = ComponentArray(x = fill(1000.0, n), y = fill(1000.0, n)) lprob1 = ODEProblem(f1, lu_0, (0, 100.0), p) cprob1 = ODEProblem(f1, cu_0, (0, 100.0), p) - solve(lprob1, Rodas5()); - solve(lprob1, Rodas5(autodiff = false)); - solve(cprob1, Rodas5()); - solve(cprob1, Rodas5(autodiff = false)); + solve(lprob1, Rodas5()) + solve(lprob1, Rodas5(autodiff = false)) + solve(cprob1, Rodas5()) + solve(cprob1, Rodas5(autodiff = false)) - ltime1 = @elapsed lsol1 = solve(lprob1, Rodas5()); - ltime2 = @elapsed lsol2 = solve(lprob1, Rodas5(autodiff = false)); - ctime1 = @elapsed csol1 = solve(cprob1, Rodas5()); - ctime2 = @elapsed csol2 = solve(cprob1, Rodas5(autodiff = false)); + ltime1 = @elapsed lsol1 = solve(lprob1, Rodas5()) + ltime2 = @elapsed lsol2 = solve(lprob1, Rodas5(autodiff = false)) + ctime1 = @elapsed csol1 = solve(cprob1, Rodas5()) + ctime2 = @elapsed csol2 = solve(cprob1, Rodas5(autodiff = false)) - @test (ctime1 - ltime1)/ltime1 < 0.05 - @test (ctime2 - ltime2)/ltime2 < 0.05 + @test (ctime1 - ltime1) / ltime1 < 0.05 + @test (ctime2 - ltime2) / ltime2 < 0.05 end @testset "Slack Issue 2021-2-19" begin nknots = 100 - h² = (1.0/(nknots+1))^2 + h² = (1.0 / (nknots + 1))^2 function heat_conduction(du, u, p, t) u₃ = @view u[3:end] u₂ = @view u[2:(end - 1)] u₁ = @view u[1:(end - 2)] - @. du[2:(end - 1)] = (u₃ - 2*u₂ + u₁)/h² + @. du[2:(end - 1)] = (u₃ - 2 * u₂ + u₁) / h² nothing end @@ -113,7 +113,7 @@ end ltime = @elapsed solve(lprob, Tsit5(), saveat = 0.2) time = @elapsed solve(prob, Tsit5(), saveat = 0.2) - @test (ctime - time)/time < 0.1 - @test (ctime - ltime)/ltime < 0.05 + @test (ctime - time) / time < 0.1 + @test (ctime - ltime) / ltime < 0.05 end end diff --git a/test/gpu_tests.jl b/test/gpu_tests.jl index 7ec8d5c5..80334ba8 100644 --- a/test/gpu_tests.jl +++ b/test/gpu_tests.jl @@ -11,13 +11,13 @@ jlca = ComponentArray(jla, Axis(a = 1:2, b = 3:4)) @test getdata(map(identity, jlca)) isa JLArray @test all(==(0), map(-, jlca, jla)) @test all(map(-, jlca, jlca) .== 0) - @test all(==(0), map(-, jla, jlca)) broken=(pkgversion(JLArrays.GPUArrays) ≥ v"11") + @test all(==(0), map(-, jla, jlca)) broken = (pkgversion(JLArrays.GPUArrays) ≥ v"11") @test any(==(1), jlca) @test count(>(2), jlca) == 2 # Make sure mapreducing multiple arrays works - @test mapreduce(==,+,jlca,jla) == 4 + @test mapreduce(==, +, jlca, jla) == 4 @test mapreduce(abs2, +, jlca) == 30 @test all(map(sin, jlca) .== sin.(jlca) .== sin.(jla) .≈ sin.(1:4)) @@ -49,21 +49,21 @@ end @test rmul!(jlca3, 2) == ComponentArray(jla .* 2, Axis(a = 1:2, b = 3:4)) end @testset "mul!" begin - A = jlca .* jlca'; - @test_nowarn mul!(deepcopy(A), A, A, 1, 2); - @test_nowarn mul!(deepcopy(A), A', A', 1, 2); - @test_nowarn mul!(deepcopy(A), A', A, 1, 2); - @test_nowarn mul!(deepcopy(A), A, A', 1, 2); - @test_nowarn mul!(deepcopy(A), A, getdata(A'), 1, 2); - @test_nowarn mul!(deepcopy(A), getdata(A'), A, 1, 2); - @test_nowarn mul!(deepcopy(A), getdata(A'), getdata(A'), 1, 2); - @test_nowarn mul!(deepcopy(A), transpose(A), A, 1, 2); - @test_nowarn mul!(deepcopy(A), A, transpose(A), 1, 2); - @test_nowarn mul!(deepcopy(A), transpose(A), transpose(A), 1, 2); - @test_nowarn mul!(deepcopy(A), transpose(getdata(A)), A, 1, 2); - @test_nowarn mul!(deepcopy(A), A, transpose(getdata(A)), 1, 2); - @test_nowarn mul!(deepcopy(A), transpose(getdata(A)), transpose(getdata(A)), 1, 2); - @test_nowarn mul!(deepcopy(A), transpose(A), A', 1, 2); - @test_nowarn mul!(deepcopy(A), A', transpose(A), 1, 2); + A = jlca .* jlca' + @test_nowarn mul!(deepcopy(A), A, A, 1, 2) + @test_nowarn mul!(deepcopy(A), A', A', 1, 2) + @test_nowarn mul!(deepcopy(A), A', A, 1, 2) + @test_nowarn mul!(deepcopy(A), A, A', 1, 2) + @test_nowarn mul!(deepcopy(A), A, getdata(A'), 1, 2) + @test_nowarn mul!(deepcopy(A), getdata(A'), A, 1, 2) + @test_nowarn mul!(deepcopy(A), getdata(A'), getdata(A'), 1, 2) + @test_nowarn mul!(deepcopy(A), transpose(A), A, 1, 2) + @test_nowarn mul!(deepcopy(A), A, transpose(A), 1, 2) + @test_nowarn mul!(deepcopy(A), transpose(A), transpose(A), 1, 2) + @test_nowarn mul!(deepcopy(A), transpose(getdata(A)), A, 1, 2) + @test_nowarn mul!(deepcopy(A), A, transpose(getdata(A)), 1, 2) + @test_nowarn mul!(deepcopy(A), transpose(getdata(A)), transpose(getdata(A)), 1, 2) + @test_nowarn mul!(deepcopy(A), transpose(A), A', 1, 2) + @test_nowarn mul!(deepcopy(A), A', transpose(A), 1, 2) end end diff --git a/test/runtests.jl b/test/runtests.jl index 361bd735..22afc04f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -39,11 +39,18 @@ r2v(r::AbstractUnitRange) = ViewAxis(r, ShapedAxis(size(r))) ## Test setup c = (a = (a = 1, b = [1.0, 4.4]), b = [0.4, 2, 1, 45]) nt = (a = 100, b = [4, 1.3], c = c) -nt2 = (a = 5, b = [(a = (a = 20, b = 1), b = 0), (a = (a = 33, b = 1), b = 0)], - c = (a = (a = 2, b = [1, 2]), b = [1.0 2.0; 5 6])) - -ax = Axis(a = 1, b = r2v(2:3), c = ViewAxis(4:10, ( - a = ViewAxis(1:3, (a = 1, b = r2v(2:3))), b = r2v(4:7)))) +nt2 = ( + a = 5, b = [(a = (a = 20, b = 1), b = 0), (a = (a = 33, b = 1), b = 0)], + c = (a = (a = 2, b = [1, 2]), b = [1.0 2.0; 5 6]), +) + +ax = Axis( + a = 1, b = r2v(2:3), c = ViewAxis( + 4:10, ( + a = ViewAxis(1:3, (a = 1, b = r2v(2:3))), b = r2v(4:7), + ) + ) +) ax_c = (a = ViewAxis(1:3, (a = 1, b = r2v(2:3))), b = r2v(4:7)) a = Float64[100, 4, 1.3, 1, 1, 4.4, 0.4, 2, 1, 45] @@ -84,14 +91,17 @@ end @test_deprecated fastindices(:a, Val(:b)) == (Val(:a), Val(:b)) @test collect(ComponentArrays.partition(collect(1:12), 3)) == - [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]] + [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]] @test size(collect(ComponentArrays.partition(zeros(2, 2, 2), 1, 2, 2))[2, 1, 1]) == - (1, 2, 2) + (1, 2, 2) end @testset "Construction" begin - @test ca == ComponentArray(a = 100, b = [4, 1.3], c = ( - a = (a = 1, b = [1.0, 4.4]), b = [0.4, 2, 1, 45])) + @test ca == ComponentArray( + a = 100, b = [4, 1.3], c = ( + a = (a = 1, b = [1.0, 4.4]), b = [0.4, 2, 1, 45], + ) + ) @test ca_Float32 == ComponentArray(Float32.(a), ax) @test eltype(ComponentArray{ForwardDiff.Dual}(nt)) == ForwardDiff.Dual @test ca_composed.b isa ComponentArray @@ -108,8 +118,11 @@ end @test ComponentArray(dict1) isa ComponentArray @test ComponentArray(dict2).b isa ComponentArray - @test ca == ComponentVector(a = 100, b = [4, 1.3], c = ( - a = (a = 1, b = [1.0, 4.4]), b = [0.4, 2, 1, 45])) + @test ca == ComponentVector( + a = 100, b = [4, 1.3], c = ( + a = (a = 1, b = [1.0, 4.4]), b = [0.4, 2, 1, 45], + ) + ) @test cmat == ComponentMatrix(a .* a', ax, ax) @test_throws DimensionMismatch ComponentVector(sq_mat, ax) @test_throws DimensionMismatch ComponentMatrix(rand(11, 11, 11), ax, ax) @@ -124,7 +137,7 @@ end # Issue #24 @test ComponentVector(a = 1, b = 2.0f0) == ComponentVector{Float32}(a = 1.0, b = 2.0) @test ComponentVector(a = 1, b = 2 + im) == - ComponentVector{Complex{Int64}}(a = 1 + 0im, b = 2 + 1im) + ComponentVector{Complex{Int64}}(a = 1 + 0im, b = 2 + 1im) # Issue #23 sz = size(ca) @@ -170,7 +183,7 @@ end @test ComponentArray(a = T[], b = T[]) == ComponentVector{T}(a = T[], b = T[]) @test ComponentArray(a = T[], b = (;)) == ComponentVector{T}(a = T[], b = T[]) @test ComponentArray(a = Any[one(Int32)], b = T[]) == - ComponentVector{T}(a = [one(T)], b = T[]) + ComponentVector{T}(a = [one(T)], b = T[]) end @test ComponentArray(NamedTuple()) == ComponentVector{Any}() @test ComponentArray(a = []).a == [] @@ -410,25 +423,29 @@ end @testset "ComponentIndex" begin ax = getaxes(ca)[1] @test ax[:a] == ax[1] == - ComponentArrays.ComponentIndex(1, ComponentArrays.NullAxis()) + ComponentArrays.ComponentIndex(1, ComponentArrays.NullAxis()) @test ax[:c] == ax[3:4] == - ComponentArrays.ComponentIndex(3:4, ShapedAxis(size(3:4))) + ComponentArrays.ComponentIndex(3:4, ShapedAxis(size(3:4))) @test ax[:d] == ComponentArrays.ComponentIndex(5:8, Axis(a = r2v(1:3), b = 4)) @test ax[(:a, :c)] == ax[[:a, :c]] == - ComponentArrays.ComponentIndex([1, 3, 4], Axis(a = 1, c = r2v(2:3))) + ComponentArrays.ComponentIndex([1, 3, 4], Axis(a = 1, c = r2v(2:3))) ax2 = getaxes(ca2)[1] @test ax2[(:a, :c)] == ax2[[:a, :c]] == - ComponentArrays.ComponentIndex( - [1, 3:8...], Axis(a = 1, c = ViewAxis(2:7, ShapedAxis((2, 3))))) + ComponentArrays.ComponentIndex( + [1, 3:8...], Axis(a = 1, c = ViewAxis(2:7, ShapedAxis((2, 3)))) + ) @test length(ComponentArrays.ComponentIndex(1, ComponentArrays.NullAxis())) == 1 @test length(ComponentArrays.ComponentIndex(3:4, ShapedAxis(size(3:4)))) == 2 @test length(ComponentArrays.ComponentIndex(5:8, Axis(a = r2v(1:3), b = 4))) == - 4 + 4 @test length(ComponentArrays.ComponentIndex([1, 3, 4], Axis(a = 1, c = r2v(2:3)))) == - 3 - @test length(ComponentArrays.ComponentIndex( - [1, 3:8...], Axis(a = 1, c = ViewAxis(2:7, ShapedAxis((2, 3)))))) == 7 + 3 + @test length( + ComponentArrays.ComponentIndex( + [1, 3:8...], Axis(a = 1, c = ViewAxis(2:7, ShapedAxis((2, 3)))) + ) + ) == 7 end @testset "KeepIndex" begin @@ -436,25 +453,25 @@ end @test ca[KeepIndex(:b)] == ca[KeepIndex(2)] == ComponentArray(b = 2) @test ca[KeepIndex(:c)] == ca[KeepIndex(3:4)] == ComponentArray(c = [3, 4]) @test ca[KeepIndex(:d)] == ca[KeepIndex(5:8)] == - ComponentArray(d = (a = [5, 6, 7], b = 8)) + ComponentArray(d = (a = [5, 6, 7], b = 8)) @test ca[KeepIndex(1:2)] == ComponentArray(a = 1, b = 2) @test ca[KeepIndex(1:3)] == ComponentArray([1, 2, 3], Axis(a = 1, b = 2)) # Drops c axis @test ca[KeepIndex(2:5)] == - ComponentArray([2, 3, 4, 5], Axis(b = 1, c = r2v(2:3))) + ComponentArray([2, 3, 4, 5], Axis(b = 1, c = r2v(2:3))) @test ca[KeepIndex(3:end)] == - ComponentArray(c = [3, 4], d = (a = [5, 6, 7], b = 8)) + ComponentArray(c = [3, 4], d = (a = [5, 6, 7], b = 8)) @test ca[KeepIndex(:)] == ca @test cmat[KeepIndex(:a), KeepIndex(:b)] == - ComponentArray(fill(2, 1, 1), Axis(a = 1), Axis(b = 1)) + ComponentArray(fill(2, 1, 1), Axis(a = 1), Axis(b = 1)) @test cmat[KeepIndex(:), KeepIndex(:c)] == - ComponentArray((1:8) * (3:4)', getaxes(ca)[1], Axis(c = r2v(1:2))) + ComponentArray((1:8) * (3:4)', getaxes(ca)[1], Axis(c = r2v(1:2))) @test cmat[KeepIndex(2:5), 1:2] == - ComponentArray((2:5) * (1:2)', Axis(b = 1, c = r2v(2:3)), ShapedAxis(size(1:2))) + ComponentArray((2:5) * (1:2)', Axis(b = 1, c = r2v(2:3)), ShapedAxis(size(1:2))) @test cmat[KeepIndex(2), KeepIndex(3)] == - ComponentArray(fill(2 * 3, 1, 1), Axis(b = 1), FlatAxis()) + ComponentArray(fill(2 * 3, 1, 1), Axis(b = 1), FlatAxis()) @test cmat[KeepIndex(2), 3] == ComponentArray(b = 2 * 3) end end @@ -610,10 +627,12 @@ end @test ldiv!(tempmat, lu(cmat + I), cmat) isa ComponentMatrix @test ldiv!(getdata(tempmat), lu(cmat + I), cmat) isa AbstractMatrix - c = (a = 2, b = [1, 2]); + c = (a = 2, b = [1, 2]) x = ComponentArray( a = 5, b = [ - (a = 20.0, b = 3.0), (a = 33.0, b = 2.0), (a = 44.0, b = 3.0)], c = c) + (a = 20.0, b = 3.0), (a = 33.0, b = 2.0), (a = 44.0, b = 3.0), + ], c = c + ) @test ldiv!(rand(10), Diagonal(x), x) isa Vector vca2 = vcat(ca2', ca2') @@ -693,7 +712,7 @@ end @test getaxes((s1_D * s2_D) * in2) == getaxes(s1_D * (s2_D * in2)) == (Axis(y1 = 1),) @test getaxes((s2_D * s1_D) * in1) == getaxes(s2_D * (s1_D * in1)) == (Axis(y2 = 1),) @test getaxes(out1' * (s1_D * s2_D)) == getaxes(transpose(out1) * (s1_D * s2_D)) == - (FlatAxis(), Axis(u2 = 1)) + (FlatAxis(), Axis(u2 = 1)) @test ComponentArrays.ArrayInterface.lu_instance(cmat).factors isa ComponentMatrix @test ComponentArrays.ArrayInterface.parent_type(cmat) === Matrix{Float64} @@ -735,14 +754,17 @@ end "c.b[1,1]", "c.b[2,1]", "c.b[1,2]", - "c.b[2,2]" + "c.b[2,2]", ] @test label2index(ca2, "c.b") == collect(11:14) # Issue #74 - lab2 = labels(ComponentArray( - a = 1, aa = ones(2), ab = [(a = 1, aa = ones(2)), (a = 1, aa = ones(2))], - ac = (a = 1, ab = ones(2, 2)))) + lab2 = labels( + ComponentArray( + a = 1, aa = ones(2), ab = [(a = 1, aa = ones(2)), (a = 1, aa = ones(2))], + ac = (a = 1, ab = ones(2, 2)) + ) + ) @test label2index(lab2, "a") == [1] @test label2index(lab2, "aa") == collect(2:3) @test label2index(lab2, "ab") == collect(4:9) @@ -757,7 +779,7 @@ end @test sum(abs2, cmat) == sum(abs2, getdata(cmat)) # Issue #40 - r0 = [1131.340, -2282.343, 6672.423]u"km" + r0 = [1131.34, -2282.343, 6672.423]u"km" v0 = [-5.64305, 4.30333, 2.42879]u"km/s" rv0 = ComponentArray(r = r0, v = v0) zrv0 = zero(rv0) @@ -882,20 +904,20 @@ end @test all(Xstack3_d1[4, :z] .== Xstack3_noca_d1[4, :]) # Issue #254, map then stack. - Xstack4_d1 = stack(x -> ComponentArray(a = x, b = [x+1, x+2]), [5 6; 7 8]; dims = 1) # map then stack - Xstack4_noca_d1 = stack(x -> [x, x+1, x+2], [5 6; 7 8]; dims = 1) # map then stack + Xstack4_d1 = stack(x -> ComponentArray(a = x, b = [x + 1, x + 2]), [5 6; 7 8]; dims = 1) # map then stack + Xstack4_noca_d1 = stack(x -> [x, x + 1, x + 2], [5 6; 7 8]; dims = 1) # map then stack @test all(Xstack4_d1 .== Xstack4_noca_d1) @test all(Xstack4_d1[:, :a] .== Xstack4_noca_d1[:, 1]) @test all(Xstack4_d1[:, :b] .== Xstack4_noca_d1[:, 2:3]) - Xstack4_d2 = stack(x -> ComponentArray(a = x, b = [x+1, x+2]), [5 6; 7 8]; dims = 2) # map then stack - Xstack4_noca_d2 = stack(x -> [x, x+1, x+2], [5 6; 7 8]; dims = 2) # map then stack + Xstack4_d2 = stack(x -> ComponentArray(a = x, b = [x + 1, x + 2]), [5 6; 7 8]; dims = 2) # map then stack + Xstack4_noca_d2 = stack(x -> [x, x + 1, x + 2], [5 6; 7 8]; dims = 2) # map then stack @test all(Xstack4_d2 .== Xstack4_noca_d2) @test all(Xstack4_d2[:a, :] .== Xstack4_noca_d2[1, :]) @test all(Xstack4_d2[:b, :] .== Xstack4_noca_d2[2:3, :]) - Xstack4_dcolon = stack(x -> ComponentArray(a = x, b = [x+1, x+2]), [5 6; 7 8]; dims = :) # map then stack - Xstack4_noca_dcolon = stack(x -> [x, x+1, x+2], [5 6; 7 8]; dims = :) # map then stack + Xstack4_dcolon = stack(x -> ComponentArray(a = x, b = [x + 1, x + 2]), [5 6; 7 8]; dims = :) # map then stack + Xstack4_noca_dcolon = stack(x -> [x, x + 1, x + 2], [5 6; 7 8]; dims = :) # map then stack @test all(Xstack4_dcolon .== Xstack4_noca_dcolon) @test all(Xstack4_dcolon[:a, :, :] .== Xstack4_noca_dcolon[1, :, :]) @test all(Xstack4_dcolon[:b, :, :] .== Xstack4_noca_dcolon[2:3, :, :]) @@ -924,7 +946,7 @@ end idx::Int, ::ComponentVector{A, B, <:Tuple{<:Axis{NT}}}, component_name::Symbol - ) where {A, B, NT} + ) where {A, B, NT} for (comp, range) in pairs(NT) if comp == component_name return range[idx]