Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
name = "ACEfit"
uuid = "ad31a8ef-59f5-4a01-b543-a85c2f73e95c"
authors = ["William C Witt <[email protected]>, Christoph Ortner <[email protected]> and contributors"]
version = "0.1.4"
version = "0.1.5-DEV"

[deps]
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LowRankApprox = "898213cb-b102-5a47-900c-97e73b919f73"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
ParallelDataTransfer = "2dcacdae-9679-587a-88bb-8b444fb7085b"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
SharedArrays = "1a1011a3-84de-559e-8e89-a11a2f7dc383"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[weakdeps]
Expand All @@ -33,8 +31,7 @@ MLJLinearModels = "0.9"
MLJScikitLearnInterface = "0.5"
LowRankApprox = "0.5.3"
Optim = "1.7"
ParallelDataTransfer = "0.5.0"
ProgressMeter = "1.7"
ProgressMeter = "1.8"
PythonCall = "0.9"
StaticArrays = "1.5"

Expand Down
66 changes: 24 additions & 42 deletions src/assemble.jl
Original file line number Diff line number Diff line change
@@ -1,56 +1,38 @@
using Distributed
using ParallelDataTransfer
using ProgressMeter
using SharedArrays

# Pairs one element of the data set (`data`) with the range of rows (`rows`)
# that its observations occupy in the assembled arrays.
struct DataPacket{T <: AbstractData}
rows::UnitRange
data::T
end

# The length of a packet is the number of observations in its data. Packets
# are sorted by this value, largest first, before being handed to `pmap`.
Base.length(d::DataPacket) = count_observations(d.data)

"""
assemble(data::AbstractArray, basis; batch_size=1, kwargs...)

Assemble feature matrix and target vector for given data and basis.
`batch_size` is passed to `pmap`. `kwargs` are used to control the
`feature_matrix`, `target_vector` and `weight_vector` calculations.
"""
function assemble(data::AbstractVector{<:AbstractData}, basis)
@info "Assembling linear problem."
rows = Array{UnitRange}(undef, length(data)) # row ranges for each element of data
rows[1] = 1:count_observations(data[1])
for i in 2:length(data)
rows[i] = rows[i - 1][end] .+ (1:count_observations(data[i]))
end
packets = DataPacket.(rows, data)
sort!(packets, by = length, rev = true)
(nprocs() > 1) && sendto(workers(), basis = basis)
@info " - Creating feature matrix with size ($(rows[end][end]), $(length(basis)))."
A = SharedArray(zeros(rows[end][end], length(basis)))
Y = SharedArray(zeros(size(A, 1)))
@info " - Beginning assembly with processor count: $(nprocs())."
@showprogress pmap(packets) do p
A[p.rows, :] .= feature_matrix(p.data, basis)
Y[p.rows] .= target_vector(p.data)
GC.gc()
function assemble(data::AbstractArray, basis; batch_size=1, kwargs...)
W = Threads.@spawn ACEfit.assemble_weights(data; kwargs...)
raw_data = @showprogress desc="Assembly progress:" pmap( data; batch_size=batch_size ) do d
A = ACEfit.feature_matrix(d, basis; kwargs...)
Y = ACEfit.target_vector(d; kwargs...)
(A, Y)
end
@info " - Assembly completed."
return Array(A), Array(Y), assemble_weights(data)
A = [ a[1] for a in raw_data ]
Y = [ a[2] for a in raw_data ]

A_final = reduce(vcat, A)
Y_final = reduce(vcat, Y)
return A_final, Y_final, fetch(W)
end

"""
assemble_weights(data::AbstractArray; kwargs...)

Assemble full weight vector for vector of data elements.
`kwargs` are passed through to the `weight_vector` calculation.
"""
function assemble_weights(data::AbstractVector{<:AbstractData})
@info "Assembling full weight vector."
rows = Array{UnitRange}(undef, length(data)) # row ranges for each element of data
rows[1] = 1:count_observations(data[1])
for i in 2:length(data)
rows[i] = rows[i - 1][end] .+ (1:count_observations(data[i]))
function assemble_weights(data::AbstractArray; kwargs...)
w = map( data ) do d
ACEfit.weight_vector(d; kwargs...)
end
packets = DataPacket.(rows, data)
sort!(packets, by = length, rev = true)
W = SharedArray(zeros(rows[end][end]))
@showprogress pmap(packets) do p
W[p.rows] .= weight_vector(p.data)
end
return Array(W)
end
return reduce(vcat, w)
end