Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 33 additions & 39 deletions src/Loess.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ export loess, predict
include("kd.jl")


mutable struct LoessModel{T <: AbstractFloat}
xs::AbstractMatrix{T} # An n by m predictor matrix containing n observations from m predictors
ys::AbstractVector{T} # A length n response vector
struct LoessModel{T <: AbstractFloat, N <: KDNode}
xs::Matrix{T} # An n by m predictor matrix containing n observations from m predictors
ys::Vector{T} # A length n response vector
predictions_and_gradients::Dict{Vector{T}, Vector{T}} # kd-tree vertexes mapped to prediction and gradient at each vertex
kdtree::KDTree{T}
kdtree::KDTree{T, N}
end

"""
Expand Down Expand Up @@ -80,8 +80,12 @@ function loess(
end

# distance to each point
for i in 1:n
ds[i] = euclidean(vec(vert), vec(xs[i,:]))
@inbounds for i in 1:n
s = zero(T)
for j in 1:m
s += (xs[i, j] - vert[j])^2
end
ds[i] = sqrt(s)
end

# find the q closest points
Expand Down Expand Up @@ -128,7 +132,7 @@ function loess(
]
end

LoessModel{T}(xs, ys, predictions_and_gradients, kdtree)
LoessModel(xs, ys, predictions_and_gradients, kdtree)
end

loess(xs::AbstractVector{T}, ys::AbstractVector{T}; kwargs...) where {T<:AbstractFloat} =
Expand All @@ -153,49 +157,39 @@ end
# Returns:
# A length n' vector of predicted response values.
#
function predict(model::LoessModel, z::Real)
predict(model, [z])
end

function predict(model::LoessModel, zs::AbstractVector)

Base.require_one_based_indexing(zs)

m = size(model.xs, 2)
function predict(model::LoessModel{T}, z::Number) where T
adjacent_verts = traverse(model.kdtree, (T(z),))

# in the univariate case, interpret a non-singleton zs as vector of
# ponits, not one point
if m == 1 && length(zs) > 1
return predict(model, reshape(zs, (length(zs), 1)))
end
@assert(length(adjacent_verts) == 2)
v₁, v₂ = adjacent_verts[1][1], adjacent_verts[2][1]

if length(zs) != m
error("$(m)-dimensional model applied to length $(length(zs)) vector")
if z == v₁ || z == v₂
return first(model.predictions_and_gradients[[z]])
end

adjacent_verts = traverse(model.kdtree, zs)
y₁, dy₁ = model.predictions_and_gradients[[v₁]]
y₂, dy₂ = model.predictions_and_gradients[[v₂]]

if m == 1
@assert(length(adjacent_verts) == 2)
z = zs[1]
v₁, v₂ = adjacent_verts[1][1], adjacent_verts[2][1]
b_int = cubic_interpolation(v₁, y₁, dy₁, v₂, y₂, dy₂)

if z == v₁ || z == v₂
return first(model.predictions_and_gradients[[z]])
end
return evalpoly(z, b_int)
end

y₁, dy₁ = model.predictions_and_gradients[[v₁]]
y₂, dy₂ = model.predictions_and_gradients[[v₂]]
function predict(model::LoessModel, zs::AbstractVector)
if size(model.xs, 2) > 1
throw(ArgumentError("Multivariate blending not yet implemented"))
end

b_int = cubic_interpolation(v₁, y₁, dy₁, v₂, y₂, dy₂)
predict.(Ref(model), zs)
end

return evalpoly(z, b_int)
else
error("Multivariate blending not yet implemented")
function predict(model::LoessModel, zs::AbstractMatrix)
if size(model.xs, 2) > 1
throw(ArgumentError("Multivariate blending not yet implemented"))
end
end

predict(model::LoessModel, zs::AbstractMatrix) = map(Base.Fix1(predict, model), eachrow(zs))
return [predict(model, z) for z in vec(zs)]
end

"""
tricubic(u)
Expand Down
45 changes: 25 additions & 20 deletions src/kd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,18 @@ abstract type KDNode end
struct KDLeafNode <: KDNode
end

struct KDInternalNode{T <: AbstractFloat} <: KDNode
struct KDInternalNode{T <: AbstractFloat, LN <: KDNode, RN <: KDNode} <: KDNode
j::Int # dimension on which the data is split
med::T # median value where the split occours
leftnode::KDNode
rightnode::KDNode
leftnode::LN
rightnode::RN
end


struct KDTree{T <: AbstractFloat}
xs::AbstractMatrix{T} # A matrix of n, m-dimensional observations
struct KDTree{T <: AbstractFloat, N <: KDNode}
xs::Matrix{T} # A matrix of n, m-dimensional observations
perm::Vector{Int} # permutation of data to avoid modifying xs
root::KDNode # root node
root::N # root node
verts::Set{Vector{T}}
bounds::Matrix{T} # Top-level bounding box
end
Expand Down Expand Up @@ -226,7 +226,7 @@ function build_kdtree(xs::AbstractMatrix{T},
push!(verts, T[vert...])
end

KDInternalNode{T}(j, med, leftnode, rightnode)
KDInternalNode(j, med, leftnode, rightnode)
end


Expand All @@ -246,15 +246,17 @@ end
Traverse the tree `kdtree` to the bottom and return the verticies of
the bounding hypercube of the leaf node containing the point `x`.
"""
function traverse(kdtree::KDTree, x::AbstractVector)
function traverse(kdtree::KDTree{T}, x::NTuple{N,T}) where {N,T}

m = size(kdtree.bounds, 2)

if length(x) != m
if N != m
throw(DimensionMismatch("$(m)-dimensional kd-tree searched with a length $(length(x)) vector."))
end

for j in 1:m
for j in 1:N
if x[j] < kdtree.bounds[1, j] || x[j] > kdtree.bounds[2, j]
@show x, kdtree.bounds
error(
"""
Loess cannot perform extrapolation. Predict can only be applied
Expand All @@ -266,15 +268,18 @@ function traverse(kdtree::KDTree, x::AbstractVector)

bounds = copy(kdtree.bounds)
node = kdtree.root
while !isa(node, KDLeafNode)
if x[node.j] <= node.med
bounds[2, node.j] = node.med
node = node.leftnode
else
bounds[1, node.j] = node.med
node = node.rightnode
end
end

bounds_verts(bounds)
return _traverse!(bounds, node, x)
end

_traverse!(bounds, node::KDLeafNode, x) = bounds
function _traverse!(bounds, node::KDInternalNode, x)
if x[node.j] <= node.med
bounds[2, node.j] = node.med
return _traverse!(bounds, node.leftnode, x)
else
bounds[1, node.j] = node.med
return _traverse!(bounds, node.rightnode, x)
end
end