Skip to content

Commit 1b90daa

Browse files
committed
Use cubic instead of linear interpolation in predict
The cubic interpolation is what is suggested in Cleveland and Grosse (1991) and is also what R uses. It makes the prediction function once differentiable, which is what I think most people would expect from LOESS. Also, fix a bug in the use of partialsort!: only the q'th element was ensured to be in the right location instead of all of the first q elements.
1 parent 60a5998 commit 1b90daa

File tree

2 files changed

+57
-42
lines changed

2 files changed

+57
-42
lines changed

src/Loess.jl

Lines changed: 56 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,7 @@ include("kd.jl")
1313
mutable struct LoessModel{T <: AbstractFloat}
1414
xs::AbstractMatrix{T} # An n by m predictor matrix containing n observations from m predictors
1515
ys::AbstractVector{T} # A length n response vector
16-
bs::Matrix{T} # Least squares coefficients
17-
verts::Dict{Vector{T}, Int} # kd-tree vertexes mapped to indexes
16+
predictions_and_gradients::Dict{Vector{T}, Vector{T}} # kd-tree vertexes mapped to prediction and gradient at each vertex
1817
kdtree::KDTree{T}
1918
end
2019

@@ -59,27 +58,25 @@ function loess(
5958
# correctly apply predict to unnormalized data. We should have a normalize
6059
# function that just returns a vector of scaling factors.
6160
if normalize && m > 1
61+
throw(ArgumentError("higher dimensional models not yet supported"))
6262
xs = tnormalize!(copy(xs))
6363
end
6464

6565
kdtree = KDTree(xs, cell * span, 0)
6666

6767
# map vertices to their index in the bs coefficient matrix
68-
verts = Dict{Vector{T}, Int}()
69-
for (k, vert) in enumerate(kdtree.verts)
70-
verts[vert] = k
71-
end
68+
predictions_and_gradients = Dict{Vector{T}, Vector{T}}()
7269

7370
# Fit each vertex
7471
ds = Array{T}(undef, n) # distances
7572
perm = collect(1:n)
76-
bs = Array{T}(undef, length(kdtree.verts), 1 + degree * m)
7773

7874
# TODO: higher degree fitting
7975
us = Array{T}(undef, q, 1 + degree * m)
76+
du1dt = zeros(T, m, 1 + degree * m)
8077
vs = Array{T}(undef, q)
8178

82-
for (vert, k) in verts
79+
for (_, vert) in enumerate(kdtree.verts)
8380
# reset perm
8481
for i in 1:n
8582
perm[i] = i
@@ -109,15 +106,31 @@ function loess(
109106
vs[i] = ys[pᵢ] * w
110107
end
111108

109+
# Compute the gradient of the vertex
110+
pᵢ = perm[1]
111+
for j in 1:m
112+
x = xs[pᵢ, j]
113+
xl = one(x)
114+
for l in 1:degree
115+
du1dt[j, 1 + (j - 1)*degree + l] = l * xl
116+
xl *= x
117+
end
118+
end
119+
112120
if VERSION < v"1.7.0-DEV.1188"
113121
F = qr(us, Val(true))
114122
else
115123
F = qr(us, ColumnNorm())
116124
end
117-
bs[k,:] = F\vs
125+
coefs = F\vs
126+
127+
predictions_and_gradients[vert] = [
128+
us[1, :]' * coefs; # the prediction
129+
du1dt * coefs # the gradient of the prediction
130+
]
118131
end
119132

120-
LoessModel{T}(xs, ys, bs, verts, kdtree)
133+
LoessModel{T}(xs, ys, predictions_and_gradients, kdtree)
121134
end
122135

123136
loess(xs::AbstractVector{T}, ys::AbstractVector{T}; kwargs...) where {T<:AbstractFloat} =
@@ -170,23 +183,22 @@ function predict(model::LoessModel, zs::AbstractVector)
170183
v₁, v₂ = adjacent_verts[1][1], adjacent_verts[2][1]
171184

172185
if z == v₁ || z == v₂
173-
return evalpoly(zs, model.bs[model.verts[[z]],:])
186+
return first(model.predictions_and_gradients[[z]])
174187
end
175188

176-
u = (z - v₁)/(v₂ - v₁)
189+
y₁ = model.predictions_and_gradients[[v₁]][1]
190+
dy₁ = model.predictions_and_gradients[[v₁]][2]
191+
y₂ = model.predictions_and_gradients[[v₂]][1]
192+
dy₂ = model.predictions_and_gradients[[v₂]][2]
193+
194+
b_int = cubic_interpolation(v₁, y₁, dy₁, v₂, y₂, dy₂)
177195

178-
y1 = evalpoly(zs, model.bs[model.verts[[v₁]],:])
179-
y2 = evalpoly(zs, model.bs[model.verts[[v₂]],:])
180-
return (1.0 - u) * y1 + u * y2
196+
return Base.Math.@horner(z, b_int...)
181197
else
182198
error("Multivariate blending not yet implemented")
183-
# TODO:
184-
# 1. Univariate linear interpolation between adjacent vertices.
185-
# 2. Blend these estimates. (I'm not sure how this is done.)
186199
end
187200
end
188201

189-
190202
predict(model::LoessModel, zs::AbstractMatrix) = map(Base.Fix1(predict, model), eachrow(zs))
191203

192204
"""
@@ -203,30 +215,33 @@ Returns:
203215
"""
204216
tricubic(u) = (1 - u^3)^3
205217

206-
207218
"""
208-
evalpoly(xs,bs)
209-
210-
Evaluate a multivariate polynomial with coefficients `bs` at `xs`. `bs` should be of length
211-
`1+length(xs)*d` where `d` is the degree of the polynomial.
212-
213-
bs[1] + xs[1]*bs[2] + xs[1]^2*bs[3] + ... + xs[end]^d*bs[end]
214-
219+
cubic_interpolation(x₁, y₁, dy₁, x₂, y₂, dy₂)
220+
221+
Compute the coefficients of the cubic polynomial ``f`` for which
222+
```math
223+
\begin{aligned}
224+
y₁ &= f(x₁) \\
225+
dy₁ &= f'(x₁) \\
226+
y₂ &= f(x₂) \\
227+
dy₂ &= f'(x₂) \\
228+
\end{aligned}
229+
```
215230
"""
216-
function evalpoly(xs, bs)
217-
m = length(xs)
218-
degree = div(length(bs) - 1, m)
219-
y = bs[1]
220-
for i in 1:m
221-
x = xs[i]
222-
xx = x
223-
y += xx * bs[1 + (i-1)*degree + 1]
224-
for l in 2:degree
225-
xx *= x
226-
y += xx * bs[1 + (i-1)*degree + l]
227-
end
228-
end
229-
y
231+
function cubic_interpolation(x₁, y₁, dy₁, x₂, y₂, dy₂)
232+
Δx = x₁ - x₂
233+
Δx³ = Δx^3
234+
Δy = y₁ - y₂
235+
num0 = -x₂ * (x₁ * Δx * (dy₂ * x₁ + dy₁ * x₂) + x₂ * (x₂ - 3 * x₁) * y₁) + x₁^2 * (x₁ - 3 * x₂) * y₂
236+
num1 = dy₂ * x₁ * Δx * (x₁ + 2 * x₂) - x₂ * (dy₁ * (x₁ * x₂ + x₂^2 - 2 * x₁^2) + 6 * x₁ * Δy)
237+
num2 = -(dy₁ * Δx * (x₁ + 2 * x₂)) + dy₂ * (x₁ * x₂ + x₂^2 - 2 * x₁^2) + 3 * (x₁ + x₂) * Δy
238+
num3 = (dy₁ + dy₂) * Δx - 2 * Δy
239+
return (
240+
num0 / Δx³,
241+
num1 / Δx³,
242+
num2 / Δx³,
243+
num3 / Δx³
244+
)
230245
end
231246

232247
"""

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ end
6565

6666
# Test values from R's loess expect outer vertices as they are made wider in the R/C/Fortran implementation
6767
@testset "vertices" begin
68-
@test sort(getindex.(keys(ft.verts))) == [4.0, 8.0, 10.0, 12.0, 13.0, 14.0, 15.0, 17.0, 19.0, 22.0, 25.0]
68+
@test sort(getindex.(keys(ft.predictions_and_gradients))) == [4.0, 8.0, 10.0, 12.0, 13.0, 14.0, 15.0, 17.0, 19.0, 22.0, 25.0]
6969
end
7070

7171
@testset "predict" begin

0 commit comments

Comments
 (0)