Skip to content

Commit 49fe426

Browse files
authored
Merge pull request #63 from JuliaStats/an/cubic
Use cubic instead of linear interpolation in predict
2 parents 0d63dbc + 1e4e6ac commit 49fe426

4 files changed

Lines changed: 70 additions & 49 deletions

File tree

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ jobs:
1515
fail-fast: false
1616
matrix:
1717
version:
18-
- '1.3' # Replace this with the minimum Julia version that your package supports. E.g. if your package requires Julia 1.5 or higher, change this to '1.5'.
18+
- '1.6' # Replace this with the minimum Julia version that your package supports. E.g. if your package requires Julia 1.5 or higher, change this to '1.5'.
1919
- '1' # Leave this line unchanged. '1' will automatically expand to the latest stable 1.x release of Julia.
2020
- 'nightly'
2121
os:

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0"
1111
[compat]
1212
Distances = "0.7, 0.8, 0.9, 0.10"
1313
StatsAPI = "1.1"
14-
julia = "1.3"
14+
julia = "1.6"
1515

1616
[extras]
1717
RDatasets = "ce6b1742-4840-55fa-b093-852dadbb1d8b"

src/Loess.jl

Lines changed: 56 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,7 @@ include("kd.jl")
1313
mutable struct LoessModel{T <: AbstractFloat}
1414
xs::AbstractMatrix{T} # An n by m predictor matrix containing n observations from m predictors
1515
ys::AbstractVector{T} # A length n response vector
16-
bs::Matrix{T} # Least squares coefficients
17-
verts::Dict{Vector{T}, Int} # kd-tree vertexes mapped to indexes
16+
predictions_and_gradients::Dict{Vector{T}, Vector{T}} # kd-tree vertexes mapped to prediction and gradient at each vertex
1817
kdtree::KDTree{T}
1918
end
2019

@@ -59,27 +58,25 @@ function loess(
5958
# correctly apply predict to unnormalized data. We should have a normalize
6059
# function that just returns a vector of scaling factors.
6160
if normalize && m > 1
61+
throw(ArgumentError("higher dimensional models not yet supported"))
6262
xs = tnormalize!(copy(xs))
6363
end
6464

6565
kdtree = KDTree(xs, cell * span, 0)
6666

67-
# map verticies to their index in the bs coefficient matrix
68-
verts = Dict{Vector{T}, Int}()
69-
for (k, vert) in enumerate(kdtree.verts)
70-
verts[vert] = k
71-
end
67+
# map vertices to their prediction and prediction gradient
68+
predictions_and_gradients = Dict{Vector{T}, Vector{T}}()
7269

7370
# Fit each vertex
7471
ds = Array{T}(undef, n) # distances
7572
perm = collect(1:n)
76-
bs = Array{T}(undef, length(kdtree.verts), 1 + degree * m)
7773

78-
# TODO: higher degree fitting
74+
# Initialize the regression arrays
7975
us = Array{T}(undef, q, 1 + degree * m)
76+
du1dt = zeros(T, m, 1 + degree * m)
8077
vs = Array{T}(undef, q)
8178

82-
for (vert, k) in verts
79+
for vert in kdtree.verts
8380
# reset perm
8481
for i in 1:n
8582
perm[i] = i
@@ -109,15 +106,31 @@ function loess(
109106
vs[i] = ys[pᵢ] * w
110107
end
111108

109+
# Compute the gradient of the vertex
110+
pᵢ = perm[1]
111+
for j in 1:m
112+
x = xs[pᵢ, j]
113+
xl = one(x)
114+
for l in 1:degree
115+
du1dt[j, 1 + (j - 1)*degree + l] = l * xl
116+
xl *= x
117+
end
118+
end
119+
112120
if VERSION < v"1.7.0-DEV.1188"
113121
F = qr(us, Val(true))
114122
else
115123
F = qr(us, ColumnNorm())
116124
end
117-
bs[k,:] = F\vs
125+
coefs = F\vs
126+
127+
predictions_and_gradients[vert] = [
128+
us[1, :]' * coefs; # the prediction
129+
du1dt * coefs # the gradient of the prediction
130+
]
118131
end
119132

120-
LoessModel{T}(xs, ys, bs, verts, kdtree)
133+
LoessModel{T}(xs, ys, predictions_and_gradients, kdtree)
121134
end
122135

123136
loess(xs::AbstractVector{T}, ys::AbstractVector{T}; kwargs...) where {T<:AbstractFloat} =
@@ -170,23 +183,20 @@ function predict(model::LoessModel, zs::AbstractVector)
170183
v₁, v₂ = adjacent_verts[1][1], adjacent_verts[2][1]
171184

172185
if z == v₁ || z == v₂
173-
return evalpoly(zs, model.bs[model.verts[[z]],:])
186+
return first(model.predictions_and_gradients[[z]])
174187
end
175188

176-
u = (z - v₁)/(v₂ - v₁)
189+
y₁, dy₁ = model.predictions_and_gradients[[v₁]]
190+
y₂, dy₂ = model.predictions_and_gradients[[v₂]]
191+
192+
b_int = cubic_interpolation(v₁, y₁, dy₁, v₂, y₂, dy₂)
177193

178-
y1 = evalpoly(zs, model.bs[model.verts[[v₁]],:])
179-
y2 = evalpoly(zs, model.bs[model.verts[[v₂]],:])
180-
return (1.0 - u) * y1 + u * y2
194+
return evalpoly(z, b_int)
181195
else
182196
error("Multivariate blending not yet implemented")
183-
# TODO:
184-
# 1. Univariate linear interpolation between adjacent verticies.
185-
# 2. Blend these estimates. (I'm not sure how this is done.)
186197
end
187198
end
188199

189-
190200
predict(model::LoessModel, zs::AbstractMatrix) = map(Base.Fix1(predict, model), eachrow(zs))
191201

192202
"""
@@ -203,30 +213,33 @@ Returns:
203213
"""
204214
tricubic(u) = (1 - u^3)^3
205215

206-
207216
"""
208-
evalpoly(xs,bs)
209-
210-
Evaluate a multivariate polynomial with coefficients `bs` at `xs`. `bs` should be of length
211-
`1+length(xs)*d` where `d` is the degree of the polynomial.
212-
213-
bs[1] + xs[1]*bs[2] + xs[1]^2*bs[3] + ... + xs[end]^d*bs[end]
214-
217+
cubic_interpolation(x₁, y₁, dy₁, x₂, y₂, dy₂)
218+
219+
Compute the coefficients of the cubic polynomial ``f`` for which
220+
```math
221+
\begin{aligned}
222+
y₁ &= f(x₁) \\
223+
dy₁ &= f'(x₁) \\
224+
y₂ &= f(x₂) \\
225+
dy₂ &= f'(x₂) \\
226+
\end{aligned}
227+
```
215228
"""
216-
function evalpoly(xs, bs)
217-
m = length(xs)
218-
degree = div(length(bs) - 1, m)
219-
y = bs[1]
220-
for i in 1:m
221-
x = xs[i]
222-
xx = x
223-
y += xx * bs[1 + (i-1)*degree + 1]
224-
for l in 2:degree
225-
xx *= x
226-
y += xx * bs[1 + (i-1)*degree + l]
227-
end
228-
end
229-
y
229+
function cubic_interpolation(x₁, y₁, dy₁, x₂, y₂, dy₂)
230+
Δx = x₁ - x₂
231+
Δx³ = Δx^3
232+
Δy = y₁ - y₂
233+
num0 = -x₂ * (x₁ * Δx * (dy₂ * x₁ + dy₁ * x₂) + x₂ * (x₂ - 3 * x₁) * y₁) + x₁^2 * (x₁ - 3 * x₂) * y₂
234+
num1 = dy₂ * x₁ * Δx * (x₁ + 2 * x₂) - x₂ * (dy₁ * (x₁ * x₂ + x₂^2 - 2 * x₁^2) + 6 * x₁ * Δy)
235+
num2 = -(dy₁ * Δx * (x₁ + 2 * x₂)) + dy₂ * (x₁ * x₂ + x₂^2 - 2 * x₁^2) + 3 * (x₁ + x₂) * Δy
236+
num3 = (dy₁ + dy₂) * Δx - 2 * Δy
237+
return (
238+
num0 / Δx³,
239+
num1 / Δx³,
240+
num2 / Δx³,
241+
num3 / Δx³
242+
)
230243
end
231244

232245
"""

test/runtests.jl

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,14 +65,22 @@ end
6565

6666
# Test values from R's loess expect outer vertices as they are made wider in the R/C/Fortran implementation
6767
@testset "vertices" begin
68-
@test sort(getindex.(keys(ft.verts))) == [4.0, 8.0, 10.0, 12.0, 13.0, 14.0, 15.0, 17.0, 19.0, 22.0, 25.0]
68+
@test sort(getindex.(keys(ft.predictions_and_gradients))) == [4.0, 8.0, 10.0, 12.0, 13.0, 14.0, 15.0, 17.0, 19.0, 22.0, 25.0]
6969
end
7070

7171
@testset "predict" begin
7272
# In R this is `predict(cars.lo, data.frame(speed = seq(5, 25, 1)))`.
7373
Rvals = [7.797353, 10.002308, 12.499786, 15.281082, 18.446568, 21.865315, 25.517015, 29.350386, 33.230660, 37.167935, 41.205226, 45.055736, 48.355889, 49.824812, 51.986702, 56.461318, 61.959729, 68.569313, 76.316068, 85.212121, 95.324047]
74-
@test predict(ft, [10, 15, 22]) ≈ Rvals[[6, 11, 18]] rtol=1e-5
75-
# The interpolated values are broken until https://github.com/JuliaStats/Loess.jl/pull/63 is merged
76-
@test_broken predict(ft) ≈ Rvals rtol=1e-5
74+
75+
for (x, Ry) in zip(5:25, Rvals)
76+
if 8 <= x <= 22
77+
@test predict(ft, x) ≈ Ry rtol = 1e-7
78+
else
79+
# The outer vertices are expanded by 0.105 in the original implementation. Not sure if we
80+
# want to do the same thing so meanwhile the results will deviate slightly between the
81+
# outermost vertices
82+
@test predict(ft, x) ≈ Ry rtol = 1e-3
83+
end
84+
end
7785
end
7886
end

0 commit comments

Comments (0)